Welcome to Vigges Developer Community - Open, Learning, Share
Welcome To Ask or Share your Answers For Others


0 votes
in Technique[技术] by (71.8m points)

python - variational auto encoder loss goes down but does not reconstruct input. out of debugging ideas. works on mnist but not on other data

My variational autoencoder seems to work for MNIST, but fails on slightly "harder" data.
By "fails" I mean there are at least two apparent problems:

  1. Very poor reconstruction. For example, here are sample reconstructions from the last epoch on the validation set, trained without any regularization at all: [image] [image] [image]
    The last reported losses from console are val_loss=9.57e-5, train_loss=9.83e-5 which I thought would imply exact reconstructions.
  2. The validation loss is low (which does not seem to reflect the reconstruction quality) and is always lower than the training loss, which is very suspicious. [loss plots: "losses", "losses2"]

For MNIST everything looks fine (with fewer layers!).

mnist recon

I will give as much information as I can, since I am not sure what I should provide to help anyone help me.

Firstly, here is the full code
You will notice that the loss calculation and logging are very simple and straightforward, and I can't seem to find what's wrong.

import torch
from torch import nn
import torch.nn.functional as F
from typing import List, Optional, Any
from pytorch_lightning.core.lightning import LightningModule
from Testing.Research.config.ConfigProvider import ConfigProvider
from pytorch_lightning import Trainer, seed_everything
from torch import optim
import os
from pytorch_lightning.loggers import TensorBoardLogger
# import tfmpl
import matplotlib.pyplot as plt
import matplotlib
from Testing.Research.data_modules.MyDataModule import MyDataModule
from Testing.Research.data_modules.MNISTDataModule import MNISTDataModule
from Testing.Research.data_modules.CaseDataModule import CaseDataModule
import torchvision
from Testing.Research.config.paths import tb_logs_folder
from Testing.Research.config.paths import vae_checkpoints_path
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint

class VAEFC(LightningModule):
    # see https://towardsdatascience.com/understanding-variational-autoencoders-vaes-f70510919f73
    # for possible upgrades, see https://arxiv.org/pdf/1602.02282.pdf
    # https://stats.stackexchange.com/questions/332179/how-to-weight-kld-loss-vs-reconstruction-loss-in-variational
    # -auto-encoder
    def __init__(self, encoder_layer_sizes: List, decoder_layer_sizes: List, config):
        """Build the fully connected encoder, the latent heads and the decoder.

        Args:
            encoder_layer_sizes: Layer widths of the encoder. Every consecutive
                pair except the last becomes a hidden ``nn.Linear``; the last
                pair sizes both the mean head and the log-variance head.
            decoder_layer_sizes: Layer widths of the decoder. Every consecutive
                pair becomes an ``nn.Linear``.
            config: Configuration object, stored as ``self._config``.
        """
        super(VAEFC, self).__init__()
        self._config = config
        self.logger: Optional[TensorBoardLogger] = None

        assert len(encoder_layer_sizes) >= 3, "must have at least 3 layers (2 hidden)"

        # Encoder hidden layers. Each layer must be appended to the ModuleList:
        # otherwise it is neither registered as a parameter nor ever applied.
        self._encoder_layers = nn.ModuleList()
        for in_features, out_features in zip(encoder_layer_sizes[:-2], encoder_layer_sizes[1:-1]):
            self._encoder_layers.append(nn.Linear(in_features, out_features))

        # Two parallel heads predict the latent mean and log-variance vectors
        # from the last hidden encoder width.
        self._mean_layer = nn.Linear(encoder_layer_sizes[-2], encoder_layer_sizes[-1])
        self._logvar_layer = nn.Linear(encoder_layer_sizes[-2], encoder_layer_sizes[-1])

        # Decoder layers (same registration requirement as the encoder).
        self._decoder_layers = nn.ModuleList()
        for in_features, out_features in zip(decoder_layer_sizes[:-1], decoder_layer_sizes[1:]):
            self._decoder_layers.append(nn.Linear(in_features, out_features))

        self._recon_function = nn.MSELoss(reduction='mean')
        self._last_val_batch = {}

    def _encode(self, x):
        """Run ``x`` through the encoder and return ``(mean, logvar)``.

        Each layer in ``self._encoder_layers`` is applied in order and followed
        by a ReLU; the resulting hidden activation feeds both latent heads.
        """
        hidden = x
        for encoder_layer in self._encoder_layers:
            hidden = F.relu(encoder_layer(hidden))

        return self._mean_layer(hidden), self._logvar_layer(hidden)

    def _reparametrize(self, mu, logvar):
        """Sample a latent vector using the reparameterization trick.

        Args:
            mu: latent mean.
            logvar: latent log variance.

        Returns:
            ``mu`` unchanged when the module is not in training mode;
            otherwise ``mu + eps * std`` with ``std = exp(0.5 * logvar)`` and
            ``eps`` drawn from a standard normal distribution.
        """
        if not self.training:
            return mu
        std = torch.exp(0.5 * logvar)
        # randn_like creates the noise on the same device and with the same
        # dtype as std, so no separate CPU/CUDA branches are needed and the
        # noise can never end up on a different device than std.
        eps = torch.randn_like(std)
        return mu + eps * std

    def _decode(self, z):
        """Map a latent vector ``z`` back through the decoder layers.

        Every decoder layer except the last is followed by a ReLU; the output
        of the last layer is returned as-is, without an activation.
        """
        hidden_layers = list(self._decoder_layers)[:-1]
        for decoder_layer in hidden_layers:
            z = F.relu(decoder_layer(z))

        # The final layer is linear; a sigmoid on top of it is a possible
        # alternative that is intentionally not applied here.
        output_layer = self._decoder_layers[-1]
        return output_layer(z)

    def _loss_function(self, recon_x, x, mu, logvar, reconstruction_function):
        """Return the reconstruction loss plus the weighted KL divergence.

        Args:
            recon_x: generated (reconstructed) images.
            x: original images.
            mu: latent mean.
            logvar: latent log variance.
            reconstruction_function: callable invoked as
                ``reconstruction_function(recon_x, x)`` to obtain the
                reconstruction term.

        Returns:
            ``reconstruction + kld * self._config.regularization_factor``.
        """
        # TODO see if mse or cross entropy
        reconstruction_loss = reconstruction_function(recon_x, x)

        # KL divergence (Kullback-Leibler divergence), the regularization term
        # of the VAE: kld = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2).
        # It measures how far the predicted latent distribution is from a
        # standard normal; we pull the two together while keeping the
        # reconstruction loss low.
        # see https://towardsdatascience.com/understanding-variational-autoencoders-vaes-f70510919f73
        kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

        # NOTE: the KL term is summed over all elements, whereas the scale of
        # the reconstruction term depends on the reduction used by
        # reconstruction_function; regularization_factor balances the two.
        # read on weighting the regularization term here:
        # https://stats.stackexchange.com/questions/332179/how-to-weight-kld-loss-vs-reconstruction-loss-in-variational
        # -auto-encoder
        return reconstruction_loss + kld * self._config.regularization_factor

    def _parse_batch_by_dataset(self, batch, batch_index):
        """Unpack a raw batch according to ``self._config.dataset``.

        Args:
            batch: batch as produced by the data loader. For "toy" it is
                unpacked as ``((orig, noisy), labels)``; for "mnist" and
                "case" as ``(orig, labels)``.
            batch_index: index of the batch (unused here).

        Returns:
            Tuple ``(orig_batch, noisy_batch, label_batch)``. ``noisy_batch``
            is always flattened to two dimensions ``(batch, -1)``.

        Raises:
            ValueError: if the configured dataset is not "toy", "mnist" or
                "case".
        """
        dataset = self._config.dataset
        if dataset == "toy":
            (orig_batch, noisy_batch), label_batch = batch
            # TODO put in the noise here and not in the dataset?
        elif dataset == "mnist":
            orig_batch, label_batch = batch
            orig_batch = orig_batch.reshape(-1, 28 * 28)
            noisy_batch = orig_batch
        elif dataset == "case":
            orig_batch, label_batch = batch
            case_config = self._config.case
            # Flatten every sample to one row of features * frames values.
            orig_batch = orig_batch.float().reshape(
                -1,
                len(case_config.feature_list) * case_config.frames_per_pd_sample,
            )
            noisy_batch = orig_batch
        else:
            # Only unknown dataset names are rejected; known ones fall through.
            raise ValueError("invalid dataset")
        noisy_batch = noisy_batch.view(noisy_batch.size(0), -1)

        return orig_batch, noisy_batch, label_batch

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch.

        The batch is unpacked via ``_parse_batch_by_dataset``, the noisy input
        is passed through ``self.forward`` and the result is scored against
        the original input with ``_loss_function``.

        Args:
            batch: raw batch from the training data loader.
            batch_idx: index of the batch within the epoch.

        Returns:
            The loss tensor for this batch.
        """
        orig_batch, noisy_batch, label_batch = self._parse_batch_by_dataset(batch, batch_idx)

        recon_batch, mu, logvar = self.forward(noisy_batch)

        # The reconstruction is compared against the original (not the noisy)
        # batch, using the reconstruction criterion stored on the module.
        loss = self._loss_function(
            recon_batch,
            orig_batch,
            mu,
            logvar,
            reconstruction_function=self._recon_function,
        )

        # Write to the shared "losses" chart on the TensorBoard writer ...
        tb = self.logger.experiment
        tb.add_scalars("losses", {"train_loss": loss}, global_step=self.current_epoch)
        # ... and log the epoch-aggregated value for the progress bar.
        self.log("train_loss", loss, prog_bar=True, on_step=False, on_epoch=True)
        return loss

    def _plot_batches(self, orig_batch, noisy_batch, label_batch, batch_idx, recon_batch, mu, logvar):
        # orig_batch_view = orig_batch.reshape(-1, self._config.case.frames_per_pd_sample,
        # len(self._config.case.feature_list))
        # plt.figure()
        # plt.plot(orig_batch_view[11, :, 0].detach().cpu().numpy(), label="feature 0")
        # plt.legend(loc="upper left")
        # plt.show()

        tb = self.logger.experiment
        if self._config.dataset == "mnist":
            orig_batch -= orig_batch.min()
            orig_batch /= orig_batch.max()
            recon_batch -= recon_batch.min()
            recon_batch /= recon_batch.max()

            orig_grid = torchvision.utils.make_grid(orig_batch.view(-1, 1, 28, 28))
            val_recon_grid = torchvision.utils.make_grid(recon_batch.view(-1, 1, 28, 28))

            tb.add_image("original_val", orig_grid, global_step=self.current_epoch)
            tb.add_image("reconstruction_val", val_recon_grid, global_step=self.current_epoch)

            label_img = orig_batch.view(-1, 1, 28, 28)
        elif self._config.dataset == "case":
            orig_batch_view = orig_batch.reshape(-1, self._config.case.frames_per_pd_sample,
                                                 len(self._config.case.feature_list)).transpose(1, 2)
            recon_batch_view = recon_batch.reshape(-1, self._config.case.frames_per_pd_sample,
                                                   len(self._config.case.feature_list)).transpose(1, 2)

            # plt.figure()
            # plt.plot(orig_batch_view[11, 0, :].detach().cpu().numpy())
            # plt.show()
            # pass

            n_samples = orig_batch_view.shape[0]
            n_plots = min(n_samples, 4)
            first_sample_idx = 0

            # TODO either plotting or data problem
            fig, axs = plt.subplots(n_plots, 1)
            for sample_idx in range(n_plots):
                for feature_idx, (orig_feature, recon_feature) in enumerate(

Welcome To Ask or Share your Answers For Others

1 Answer

0 votes
by (71.8m points)

Welcome to Vigges Developer Community for programmer and developer-Open, Learning and Share