# TD8: Dimensionality reduction with Deep Learning & TDA

In 2021, [Carrière et al.](https://arxiv.org/pdf/2010.08356) proposed a general framework to define and compute gradients of persistence-based functions, making it possible to optimize persistence-based loss functions directly in deep learning architectures. Concurrently, several works have shown that combining deep learning with TDA-based losses can yield meaningful low-dimensional embeddings of data ([Moor et al., 2021](https://arxiv.org/pdf/1906.00722), [Vandaele et al., 2022](https://arxiv.org/pdf/2110.09193)).

The goal of this lab is to implement an autoencoder to reduce the dimensionality of the Fashion-MNIST dataset to 2D. During the lab, you will either have to write the full code or complete the provided snippets. Some code snippets are also pre-written and can be executed directly, but you are welcome to re-implement everything if you prefer.


Requirements:
```bash
umap-learn
gudhi
numpy
scikit-learn
matplotlib
plotly
torch
torchvision
torch_topological
```


In [None]:
!pip install umap-learn
!pip install gudhi
!pip install numpy
!pip install scikit-learn
!pip install matplotlib
!pip install plotly
!pip install torch
!pip install torchvision
!pip install torch_topological

In [None]:
import os

import random
import plotly.graph_objects as go
import numpy as np
from sklearn.decomposition import PCA
from matplotlib import pyplot as plt
from umap import UMAP
import sklearn
from sklearn.utils import shuffle
from sklearn.metrics import accuracy_score

# -- deep learning packages
import torch
from torchvision import datasets
import matplotlib.pyplot as plt
import torchvision
import torch.nn as nn


# -- TDA packages
import gudhi
import gudhi.representations
from sklearn import decomposition
from sklearn.manifold import MDS
from sklearn.neighbors import KNeighborsClassifier
from gudhi.representations import DiagramSelector
from torch_topological.nn import VietorisRipsComplex
from torch_topological.nn import WassersteinDistance


# classes of the dataset 
CLASSES = ["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat", "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"]

In [None]:
def seed_everything(seed: int):
    """fix the seed for reproducibility."""
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.use_deterministic_algorithms(True)
    torch.backends.cudnn.deterministic = True
    # benchmark mode picks algorithms non-deterministically, so it must be off
    torch.backends.cudnn.benchmark = False

seed_everything(42)

### Importing & Exploring the dataset

Fashion-MNIST is a dataset of Zalando's article images, consisting of a training set of 60,000 examples and a test set of 10,000 examples. Each example is a 28x28 grayscale image, associated with a label from 10 classes (T-shirt/top, Trouser, Pullover, Dress, Coat, Sandal, Shirt, Sneaker, Bag, Ankle boot). The training and test sets share the same image size and structure.

Here is an overview of the dataset:

![fashionMNIST](https://github.com/zalandoresearch/fashion-mnist/blob/master/doc/img/fashion-mnist-sprite.png?raw=true)

Let's access the dataset using `torchvision`. The training set is stored in `training_data` and the test set in `test_data`. Each set behaves like a list of (image, label) tuples. They are respectively of size 60000 and 10000.

In [None]:
training_data = datasets.FashionMNIST(root="data", download=True, train=True)
test_data = datasets.FashionMNIST(root="data", download=True, train=False)

Q.1 Display a few images, along with their labels.

In [None]:
plt.imshow(training_data[1000][0], cmap="gray")
plt.title(f"Label: {CLASSES[training_data[1000][1]]}") 

We are going to use flattened versions of the images, so that each image becomes a single vector of 784 values. We will also divide the pixel values by 255 and normalize them with a mean and standard deviation of 0.5, which maps them to the range $[-1, 1]$.
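
As a quick sanity check of this normalization, here is a minimal sketch on a made-up 2x2 "image" (the pixel values are hypothetical and chosen only for illustration). It shows that a pixel value of 0 is sent to -1 and a value of 255 to 1.

```python
import numpy as np

# Hypothetical 2x2 "image" with 8-bit pixel values, used only for illustration.
img = np.array([[0, 255], [51, 204]], dtype=np.uint8)

# Flatten to a vector, scale to [0, 1], then shift and scale with mean 0.5 and std 0.5.
x = (img.flatten() / 255 - 0.5) / 0.5

print(x.shape)           # one vector per image
print(x.min(), x.max())  # the extreme pixel values map to the ends of the range
```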

In [None]:
X_test = [ (np.ndarray.flatten(np.array(img)) / 255 - 0.5) / 0.5 for img,_ in test_data]
y_test = [cat for _, cat in test_data]
X_train = [ (np.ndarray.flatten(np.array(img)) / 255 - 0.5) / 0.5 for img,_ in training_data]
y_train = [cat for _, cat in training_data]


X_train, y_train = shuffle(X_train, y_train, random_state=42)

In [None]:
idx = 10
plt.imshow(X_train[idx].reshape(28,28), cmap="gray")
plt.title(f"Label: {CLASSES[y_train[idx]]}")

Q.2 Visualize the test set using a PCA in 2d. The PCA should be fitted on the train set.

In [None]:
# Fit the method
reductor = PCA(n_components=2)
reductor.fit(X_train)
testing_reduction = reductor.transform(X_test)    
fig, ax = plt.subplots()
scatter = ax.scatter(testing_reduction[:, 0],
                    testing_reduction[:, 1],
                    c=y_test,
                    s=2,
                    alpha=0.3,
                    cmap='Set3')
fig.colorbar(scatter)
ax.set_title('PCA on Fashion-MNIST')

### Training an Autoencoder

**An auto-encoder** is defined as the composition of two functions, $h \circ g$,
where $g : X \to Z$ represents the encoder and $h : Z \to X'$
represents the decoder. We denote the latent space by $Z := g(X)$, and the reconstruction by $X' := h(Z)$. The auto-encoder is trained to minimize the reconstruction error, i.e., the distance between the input and the output, $L(X, X')$.
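
To make the reconstruction error concrete, here is a small sketch assuming $L$ is the mean squared error, i.e. the average of $(x_{ij} - x'_{ij})^2$ over all samples $i$ and pixels $j$. The tensors below are hypothetical toy values standing in for a batch $X$ and its reconstruction $X'$; the sketch only checks that the manual formula agrees with `nn.MSELoss`.

```python
import torch
import torch.nn as nn

# Toy batch: 2 "images" of 3 pixels each (hypothetical values).
x = torch.tensor([[1.0, 0.0, -1.0], [0.5, 0.5, 0.5]])
# Hypothetical reconstructions, standing in for h(g(x)).
x_rec = torch.tensor([[0.0, 0.0, -1.0], [0.5, 0.5, -0.5]])

# Manual MSE: average of the squared differences over all entries.
manual = ((x - x_rec) ** 2).mean()
# Same quantity with the built-in loss (default reduction="mean").
builtin = nn.MSELoss()(x_rec, x)

print(manual.item(), builtin.item())
```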

Q.3 Now, let's implement the model `AutoEncoder`. Complete the following code (in the sections marked by TODOs) to:
- define the encoder
- define the decoder
- complete the encode, decode and forward functions.

The encoder is composed of three linear layers of decreasing size $(500 - 250 - 2)$. The decoder is
a sequence of three linear layers of increasing size $(250 - 500 - 784)$.
Between consecutive linear layers, we use a `ReLU` activation function followed by `BatchNorm1d` (normalization of the layer's inputs by re-centering and re-scaling) for faster and more stable training. The final layer of the decoder uses a tanh activation function, so that its output range matches the range of the input images, which are scaled between −1 and 1.
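
The effect of these two components can be checked on toy data. The sketch below uses a hypothetical random batch (not Fashion-MNIST) to illustrate that `BatchNorm1d` in training mode re-centers and re-scales each feature over the batch, and that tanh outputs stay between −1 and 1.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical batch: 64 samples, 5 features, deliberately shifted and scaled.
x = 3.0 * torch.randn(64, 5) + 10.0

bn = nn.BatchNorm1d(5)
bn.train()  # in training mode, statistics are computed from the current batch
x_bn = bn(x)

# Each feature is re-centered and re-scaled over the batch.
print(x_bn.mean(dim=0))                 # close to 0 for every feature
print(x_bn.std(dim=0, unbiased=False))  # close to 1 for every feature

# tanh keeps its outputs between -1 and 1, like the normalized pixel values.
y = torch.tanh(torch.tensor([-50.0, -1.0, 0.0, 1.0, 50.0]))
print(y)
```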


Notes:
- How are torch models structured?
A PyTorch model is typically structured as a subclass of `torch.nn.Module`, which defines the architecture and behavior of the model. Here is a brief outline of its structure:
1. Initialization (`__init__` method): define the model's layers and components, such as linear layers and activation functions, using PyTorch modules (`torch.nn.*`).
2. Forward pass (`forward` method): specify how the input data flows through the layers of the model.
3. Instantiation and training: once defined, the model is instantiated and used in a training loop.
```python
model = MyModel()
output = model(input_data)  # Calls the forward method
```
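
For a complete, runnable version of this outline, here is a minimal sketch. `MyModel`, its layer sizes and the random input are hypothetical and only serve to illustrate the three points above; they are not the architecture used in this lab.

```python
import torch
import torch.nn as nn


class MyModel(nn.Module):
    """Hypothetical two-layer network, used only to illustrate the structure."""

    def __init__(self, in_features=4, hidden=8, out_features=2):
        super().__init__()
        # 1. Initialization: declare the layers.
        self.layers = nn.Sequential(
            nn.Linear(in_features, hidden),
            nn.ReLU(),
            nn.Linear(hidden, out_features),
        )

    def forward(self, x):
        # 2. Forward pass: describe how the input flows through the layers.
        return self.layers(x)


# 3. Instantiation: calling the model invokes `forward`.
model = MyModel()
input_data = torch.randn(5, 4)  # batch of 5 samples with 4 features each
output = model(input_data)
print(output.shape)  # one 2-dimensional output per sample
```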

- The encoder can be defined as:
```python
encoder = nn.Sequential(
    nn.Linear(784, 500),
    nn.ReLU(),
    nn.BatchNorm1d(500),
    nn.Linear(500, 250),
    nn.ReLU(),
    nn.BatchNorm1d(250),
    nn.Linear(250, 2)
)
```


In [None]:
class AutoEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        # - start TODO: define the encoder
        self.encoder = nn.Sequential(
            nn.Linear(784, 500),
            nn.ReLU(),
            nn.BatchNorm1d(500),
            nn.Linear(500, 250),
            nn.ReLU(),
            nn.BatchNorm1d(250),
            nn.Linear(250, 2)
        )
        # - TODO: define the decoder
        self.decoder = nn.Sequential(
            nn.Linear(2, 250),
            nn.ReLU(),
            nn.BatchNorm1d(250),
            nn.Linear(250, 500),
            nn.ReLU(),
            nn.BatchNorm1d(500),
            nn.Linear(500, 784),
            nn.Tanh()
        )


    def encode(self, x):
        """Compute the latent representation using the encoder.
        
        Parameters: 
        -----------
        x: input batch [bs x (n_row * n_col)]

        Returns: 
        --------
            latent representations [bs x 2]
        """
        # - TODO: complete the encode function 
        return self.encoder(x)

    def decode(self, z):
        """Compute the reconstruction using the decoder.
        
        Parameters:
        -----------
        z: latent representation [bs x 2]

        Returns:
        --------
            reconstructed images [bs x (n_row * n_col)]
        """
        # - TODO: complete the decode function 
        return self.decoder(z)

    def forward(self, x):
        """Apply the autoencoder to a batch of input images.

        Parameters:
        -----------
            x: Batch of images with shape [bs x (n_row * n_col) ]

        Returns:
        --------
            (reconstructed images [bs x (n_row * n_col)], latent representations [bs x 2])

        """
        # - TODO complete the forward function
        latent = self.encode(x)
        output = self.decode(latent)
        return output, latent

Q.4 Complete the `fit_model` function to train the model. The function takes as input the model, the training and test sets, the training hyperparameters (number of epochs, learning rate, batch size) and the loss function. It returns the trained model.

There are two parts to complete:
- in the training loop: retrieve the reconstruction and latent representation given by the model, compute the loss and store it.
- in the validation loop: iterate over the test set to check that training also improves on unseen data. Retrieve the reconstruction and latent representation given by the model, then compute and store the test loss.

**Notes:**

Training a deep neural network usually involves the following steps:
- Forward pass: input data is passed through the network to produce predictions.
- Loss calculation: the difference between the predictions and the target values is measured using a loss function.
- Backward pass: gradients of the loss with respect to the model parameters are computed using backpropagation.
- Parameter update: an optimizer (e.g. gradient descent or Adam) uses the gradients to adjust the parameters, minimizing the loss over time.

In practice, we use the following terms:
- Epoch: one full pass through the entire training dataset.
- Batch: subset of the training dataset used during one forward and backward pass of the neural network. Instead of processing the entire dataset at once (which can be computationally expensive), the dataset is divided into smaller groups of samples, called batches.
- Iterations: number of updates per epoch, equal to (dataset size ÷ batch size), rounded up.

This cycle repeats over many epochs until the network converges (achieves good performance).
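
The steps and vocabulary above can be illustrated with a minimal sketch on a toy regression problem. The data, the linear model and the hyperparameters below are hypothetical and unrelated to Fashion-MNIST; the point is only to show where each step appears in a PyTorch loop.

```python
import math

import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy data: 100 samples, 3 features, linear target.
X = torch.randn(100, 3)
y = X @ torch.tensor([[1.0], [-2.0], [0.5]])

dataset = torch.utils.data.TensorDataset(X, y)
batch_size = 32
loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Iterations per epoch: dataset size / batch size, rounded up
# (the last batch is smaller when the division is not exact).
n_iterations = math.ceil(len(dataset) / batch_size)

model = nn.Linear(3, 1)
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

epoch_losses = []
for epoch in range(20):  # one epoch = one full pass over the dataset
    batch_losses = []
    for xb, yb in loader:  # one iteration per batch
        optimizer.zero_grad()       # clear gradients from the previous step
        pred = model(xb)            # forward pass
        loss = criterion(pred, yb)  # loss calculation
        loss.backward()             # backward pass (backpropagation)
        optimizer.step()            # parameter update
        batch_losses.append(loss.item())
    epoch_losses.append(sum(batch_losses) / len(batch_losses))

print(n_iterations, epoch_losses[0], epoch_losses[-1])
```

With the 60,000 training images and `batch_size=128` used in this lab, the same formula gives 469 iterations per epoch.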


In [None]:
def fit_model(model, train, test, n_epochs=10, lr=1e-3, batch_size=128, criterion=nn.MSELoss()):
    """Training loop.

    Parameters
    ----------
    model : nn.Module
        The model to train.
    train : list of flattened training images
        The training dataset.
    test : list of flattened testing images
        The validation dataset.
    n_epochs : int
        The number of epochs.
    lr : float
        The learning rate.
    batch_size : int
        The batch size.
    criterion : nn.Module
        The loss function.

    Returns
    -------
    model : nn.Module
        The trained model.
    """
    # -- DATA SET-UP
    # - convert the data to tensors (stacking the list into a single array
    #   first is much faster than converting a list of arrays)
    X_train = torch.tensor(np.array(train), dtype=torch.float)
    X_test = torch.tensor(np.array(test), dtype=torch.float)
    # - create the train loader
    # a dataloader manages and feeds data to a model during training
    # or evaluation by splitting the dataset into batches (small chunks),
    # since we do not pass all the data at once.
    train_loader = torch.utils.data.DataLoader(X_train, batch_size=batch_size, shuffle=True)
    # - create the validation loader
    test_loader = torch.utils.data.DataLoader(X_test, batch_size=batch_size, shuffle=False)
    # -- END DATA SET-UP

    # - initialize the optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5)
    
    loss_train, loss_val = [], []

    # - iterate over the number of epochs
    for epoch in range(n_epochs):
        # - set the model to training.
        model.train()
        loss_step, loss_val_step = [], []

        # --- iterate over the training set
        for batch_idx, data in enumerate(train_loader):

            # - clear the past gradients
            optimizer.zero_grad()

            # ---- start TODO ----
            # retrieve the reconstruction and latent representation given by the model
            # and compute the loss.
            output, latent = model(data)
            if isinstance(criterion, nn.MSELoss):
                # same argument order as in the validation loop: (prediction, target)
                loss = criterion(output, data)
            else:
                loss = criterion(data, latent, output)
            loss_step.append(loss.item())
            # ----- end TODO ----

            # - compute the gradients of the loss
            loss.backward()
            # - update the weights
            optimizer.step()
            # - display information
            if batch_idx % 100 == 0:
                print(f'Epoch: {epoch}, Batch: {batch_idx}, Loss: {loss.item()}')

        # --- iterate over the validation set to see how the training is going
        # - set the model mode to eval 
        model.eval()
        # - do not compute the gradients here, since the weights won't be updated
        with torch.no_grad():
            # --- start TODO ---
            # iterate over the test set
            # retrieve the reconstruction and latent representation given by the model
            # compute and store the loss
            for data in test_loader:
                output, latent = model(data)
                if isinstance(criterion, nn.MSELoss):
                    val_loss = criterion(output, data)
                else:
                    val_loss = criterion(data, latent, output)
                loss_val_step.append(val_loss.item())
            # --- end TODO ---

        # - compute the average loss per epoch
        loss_train.append(np.mean(loss_step))
        loss_val.append(np.mean(loss_val_step))

        # - display the average validation loss of the epoch
        print(f'Epoch: {epoch}, Val Loss: {loss_val[-1]}')

    # - plots the losses
    fig, ax = plt.subplots()
    ax.plot(loss_train, label='Train')
    ax.plot(loss_val, label='Validation')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.legend()
    plt.show()
    return model

Q.5 Train the model with the reconstruction loss between $X$ and $X'$. We will use the mean squared error (MSE, `nn.MSELoss()`).

You can use the following hyperparameters:
- `n_epochs=10`
- `lr=1e-4`
- `batch_size=128`

Training takes approximately 4 minutes with these hyperparameters on Google Colab.

In [None]:
auto_encoder = fit_model(AutoEncoder(), X_train, X_test, n_epochs=10, lr=1e-4, batch_size=128, criterion=nn.MSELoss())

Q.6 Visualize the latent space given by the autoencoder on the test set, and display some reconstructed images. 

Note:
- you can get help from the data set-up and validation loop implemented in `fit_model` (Q.4)
- you can adjust the axis limits of your plot for better visualization

In [None]:
def test_model(model, X_test, y_test, name='autoencoder', y_lim=(-25, 25), x_lim=(-10, 10)):
    model.eval()

    # stack the list of images into a single array before converting to a tensor
    X_test = torch.tensor(np.array(X_test), dtype=torch.float)

    test_loader = torch.utils.data.DataLoader(X_test, batch_size=128, shuffle=False)
    latents, reconstructs = [], []
    with torch.no_grad():
        for test_batch in test_loader:
            predicted, latent = model(test_batch)
            latents.append(latent)
            reconstructs.append(predicted)

    latents = torch.cat(latents, dim=0).numpy()
    reconstructs = torch.cat(reconstructs, dim=0).numpy()

    fig, ax = plt.subplots()
    scatter = ax.scatter(latents[:, 0], latents[:, 1], c=y_test, cmap='Set3', s=2, alpha=0.4)
    ax.set_ylim(y_lim)
    ax.set_xlim(x_lim)
    fig.colorbar(scatter)
    ax.set_title(name)
    return fig, reconstructs


In [None]:
fig, reconstructs = test_model(auto_encoder, X_test, y_test, name='Autoencoder', y_lim=(-4, 4), x_lim=(-4,2))

In [None]:
plt.imshow(reconstructs[10].reshape(28, 28), cmap="gray")
plt.title(f"Reconstruction from the latent representation (label: {CLASSES[y_test[10]]})")

Q.7 Implement the topologically regularized loss by filling in the TODOs in the following code.

The loss is defined as $L = \texttt{MSE}(X,X') + \mu \cdot W_1(D_X, D_Z)$, where $W_1(D_X, D_Z)$ is the 1-Wasserstein distance between the 0- or 1-dimensional persistence diagrams $D_X$ and $D_Z$ of the Vietoris-Rips filtrations of the input space and the latent space, respectively.
It should take as parameters the dimension of the persistent homology (PH) computed, the exponent `p` of the norm used for pairwise distances in the Vietoris-Rips complex, and the coefficient $\mu$.

Note:
- Custom losses follow the same template as a torch model (`__init__` method and forward pass).
- You can use the `torch_topological` package to compute the persistence diagrams (`torch_topological.nn.VietorisRipsComplex`) and the Wasserstein distances (`torch_topological.nn.WassersteinDistance`).
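
To build some intuition for what the 0-dimensional diagram of a Vietoris-Rips filtration contains, here is a small NumPy sketch that is independent of `torch_topological`. It relies on a standard fact: every connected component is born at 0, and the finite death times are the edge lengths of a minimum spanning tree of the point cloud (with the convention that an edge enters the filtration at the distance between its endpoints). The helper `zero_dim_death_times` and the point cloud are hypothetical and written only for this illustration.

```python
import numpy as np


def zero_dim_death_times(points):
    """Finite death times of the 0-dimensional Vietoris-Rips diagram.

    Computed as the edge lengths of a minimum spanning tree (Prim's algorithm).
    """
    points = np.asarray(points, dtype=float)
    n = len(points)
    # Pairwise Euclidean distances.
    dist = np.linalg.norm(points[:, None, :] - points[None, :, :], axis=-1)

    in_tree = np.zeros(n, dtype=bool)
    in_tree[0] = True
    # Cheapest known connection from the tree to each point.
    best = dist[0].copy()
    deaths = []
    for _ in range(n - 1):
        # Ignore points already in the tree, then pick the closest remaining one.
        candidates = np.where(in_tree, np.inf, best)
        j = int(np.argmin(candidates))
        deaths.append(candidates[j])
        in_tree[j] = True
        best = np.minimum(best, dist[j])
    return np.sort(np.array(deaths))


# Hypothetical point cloud: two tight pairs of points, far from each other.
cloud = np.array([[0.0, 0.0], [1.0, 0.0], [10.0, 0.0], [10.0, 2.0]])
print(zero_dim_death_times(cloud))
```

The two small death times correspond to the merges inside each pair, and the large one to the merge of the two pairs. The regularizer compares such diagrams between the input batch and its latent representation.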


In [None]:
class TopologicalRegularizer(nn.Module):
    def __init__(self, dim, p, mu):
        super().__init__()
        # - TODO: complete the init
        # by initializing the different components
        # (i.e. MSE loss, VR complex, …)
        self.persistence_latent = VietorisRipsComplex(dim=dim, p=p)
        self.persistence_data = VietorisRipsComplex(dim=dim, p=p)
        self.distance = WassersteinDistance(p=1)
        self.mse = nn.MSELoss()
        self.dim = dim
        self.mu = mu

    def forward(self, input, latent, reconstruction):
        """
        Forward of the loss

        Parameters:
        ----------
        input: [bs x (n_rows * n_cols)]
            original image
        latent: [bs x 2]
            latent space given by the model
        reconstruction: [bs x (n_rows * n_cols)]
            reconstructed images
        
        """
        # - TODO: complete the forward method
        loss = torch.tensor(0.0).to(input.device)
        # - topo reg.
        ph_latent = self.persistence_latent(latent)
        ph_original = self.persistence_data(input)
        loss += self.mu * self.distance(ph_latent[self.dim], ph_original[self.dim])
        # - mse
        loss += self.mse(input, reconstruction)
        return loss

Q.8 Train the model with the topological regularizer with the following params:
- `n_epochs=10`
- `lr=1e-4`
- `batch_size=128`
- `mu=0.0005` 
- `dim=0`
  
Once trained, display the resulting 2d latent space of the test set.

In [None]:
topological_criterion = TopologicalRegularizer(dim=0, p=2, mu=0.0005)
trained_topo_ae = fit_model(AutoEncoder(), X_train, X_test, n_epochs=10, lr=1e-4, batch_size=128, criterion=topological_criterion)

In [None]:
fig_topo, reconstructs = test_model(trained_topo_ae, X_test, y_test, name='Topological Autoencoder', y_lim=(-100, 100), x_lim=(-100, 100))

Q.9 Train the model with the topological regularizer with the following params:
- `n_epochs=10`
- `lr=1e-4`
- `batch_size=128`
- `mu=0.0001` 
- `dim=1`
  
Once trained, display the resulting 2d latent space of the test set.

In [None]:
topological_criterion_1 = TopologicalRegularizer(dim=1, p=2, mu=0.0001)
trained_topo_ae_1 = fit_model(AutoEncoder(), X_train, X_test, n_epochs=10, lr=1e-4, batch_size=128, criterion=topological_criterion_1)

In [None]:
fig_topo_1, reconstructs_1 = test_model(trained_topo_ae_1, X_test, y_test, name='Topological Autoencoder (dim=1)', y_lim=(-2, 2), x_lim=(-2, 2))