# TD7: Machine Learning with TDA

The goal of this lab session is to use persistent homology to derive stable signatures for 3D shape classification.


Throughout this lab, we will use `Gudhi`. It provides functions to vectorize persistence diagrams with `Scikit-Learn` estimator-like classes, that is, classes with `fit`, `transform`, and `fit_transform` methods (see [this article](https://arxiv.org/pdf/1309.0238.pdf) for more details on this API design).


Requirements:
```bash
gudhi
numpy
scikit-learn
matplotlib
plotly
```


In [None]:
!pip install gudhi
!pip install numpy
!pip install scikit-learn
!pip install matplotlib
!pip install plotly

In [None]:
import os

import sklearn
import numpy as np
import plotly.graph_objects as go
from sklearn.decomposition import PCA
from matplotlib import pyplot as plt
from sklearn.metrics import accuracy_score
from sklearn import decomposition
from sklearn.manifold import MDS
from sklearn.neighbors import KNeighborsClassifier

# -- TDA packages
import gudhi
import gudhi.representations
from gudhi.representations import DiagramSelector

In [None]:
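def seed_everything(seed: int):
    """Fix the random seeds for reproducibility."""
    import os
    import random
    import numpy as np

    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)

    # PyTorch is not needed in this lab (it is not in the requirements):
    # only seed it if it happens to be installed.
    try:
        import torch
    except ImportError:
        return
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN pick possibly non-deterministic algorithms
    torch.backends.cudnn.benchmark = False

seed_everything(42)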

## I.0. Preprocessing & exploration of the dataset

### Q.0. Download & set-up the access to the data

#### For Colab users

Q.0.1 If you use **colab**:

- *Mount your Google Drive* to allow Colab to access files in your Drive.

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


- *Access the [folder](https://drive.google.com/drive/folders/1oW5FTPNbIRPjFzlhzFLYRBkZ9iRMw3Q6?usp=drive_link)* with the data; this will automatically add it to your 'Shared with me' folder.
- *Locate the Shared Folder*, which is typically stored under `My Drive > Shared with me`.
- *Add the Shared Folder to 'My Drive'*: drag & drop the folder into the 'My Drive' folder. (Only items in 'My Drive' or folders you've added to your Drive can be accessed directly.)
- The folder with the data should now be in *My Drive/LabeledDB_new*

In [None]:
dataset_path = '/content/drive/My Drive/LabeledDB_new'

#### For local users

Q.0.1 If you use your **own machine**
- Download the data from [here](https://drive.google.com/file/d/1LnRsKQTEMLjSMtt5J6cvkIzY7YtxNoML/view?usp=sharing)
- Extract it
- Save its path in the `dataset_path` variable

In [None]:
dataset_path = "to complete"

#### Read the data

Q.0.2 Verify that everything is set up properly by listing the subdirectories of the dataset folder (it should contain 10 subdirectories, one for each class).


In [None]:
os.listdir(dataset_path)

Q.1 Write a function `read_data`, which takes as input the path to a shape and returns two `NumPy` arrays containing respectively the vertices and the faces.


Note. The dataset is split into 10 categories (Ant, Hand, Plier, Chair, Vase, Bird, FourLeg, Table, Airplane, Octopus), each category having its own folder. Inside each folder, some 3D shapes (i.e., 3D triangulations) are provided in [`.off`](https://en.wikipedia.org/wiki/OFF_(file_format)) format.

Each file is structured as follows:
- `OFF`, the standard signature (header) of .off files,
- (n_v, n_f, _), where n_v and n_f are respectively the numbers of vertices and triangles in the mesh (the third number, the number of edges, can be ignored),
- the vertices (x_n y_n z_n), i.e. the coordinates of the nth vertex,
- the faces (3 i_m j_m k_m), i.e. the IDs of the three vertices of the mth triangle, preceded by the number of vertices of the face (3 for a triangle). For instance, `3 14 0 5` denotes the triangle given by the 14th, 0th and 5th vertices.
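To make the format concrete, here is a minimal sketch that parses the *content* of an OFF file given as a string (for a real file, pass it `open(path).read()`). It assumes there are no comment lines, that the header `OFF` sits alone on the first line, and that every face is a triangle; the helper `parse_off_string` and the toy tetrahedron are illustrative, not part of the lab's code.

```python
import numpy as np

def parse_off_string(text):
    """Parse the content of a triangle-only OFF file (illustrative helper)."""
    # Keep non-empty lines; assumes there are no '#' comment lines.
    lines = [row.strip() for row in text.strip().splitlines() if row.strip()]
    if lines[0] != "OFF":
        raise ValueError("missing OFF header")
    n_v, n_f, _ = map(int, lines[1].split())
    # Next n_v lines: "x y z" coordinates of each vertex.
    vertices = np.array([[float(c) for c in row.split()] for row in lines[2:2 + n_v]])
    # Next n_f lines: "3 i j k"; drop the leading vertex count of each face.
    faces = np.array([[int(c) for c in row.split()[1:4]]
                      for row in lines[2 + n_v:2 + n_v + n_f]])
    return vertices, faces

toy_off = """OFF
4 4 0
0 0 0
1 0 0
0 1 0
0 0 1
3 0 1 2
3 0 1 3
3 0 2 3
3 1 2 3
"""
vertices, faces = parse_off_string(toy_off)
print(vertices.shape, faces.shape)  # (4, 3) (4, 3)
```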



In [None]:
def read_data(path):
    """
    Read a .off file and return the vertices and faces
    as numpy arrays given the path.

    Parameters
    ----------
    path: str
        Path to the shape file.

    Returns
    -------
    vertices: numpy array
        Array of vertices.
    faces: numpy array
        Array of faces.
    """
    # TODO
    pass

Q.2 Get the vertices and faces of the shape of your choice.

Q.3 Visualize a 3D shape. Write a function `plot_3d_shape` which takes as input the vertices and faces of a shape and plots it using the `plotly` library (with `go.Mesh3d`). Visualize the shape selected in Q.2.



In [None]:
def plot_3d_shape(vertices, faces):
    """
    Plot a mesh given its vertices and faces.

    Parameters
    ----------
    vertices: numpy array
        Vertices of the mesh.
    faces: numpy array
        Faces of the mesh

    """
    # TODO
    pass


## I.1. Compute the persistence diagrams

Let's now compute the persistence diagrams for the 3D shapes.

Q.4 Write a function `get_simplex_tree_from_faces` that builds a simplex tree from the faces of a given 3D shape triangulation.

Note. In `Gudhi`, (filtered) simplicial complexes are encoded through a data structure called a simplex tree.
See the [`gudhi.SimplexTree()` documentation](https://gudhi.inria.fr/python/latest/simplex_tree_ref.html) for a complete list of functionalities.

Here is a very simple example illustrating the use of `gudhi.SimplexTree()` to represent a simplicial complex:
```python
import gudhi

st = gudhi.SimplexTree() # create a simplex tree

# simplices can be inserted one by one
# vertices are indexed by integers
if st.insert([0,1]):
  print("first simplex inserted!")

st.insert([1,2])
st.insert([2,3])
st.insert([3,0])

# Inserting an edge automatically inserts its vertices (if they were
# not already in the complex)
filtration = st.get_filtration() # all (simplex, filtration value) pairs
for simplex in filtration:
  print(simplex)

# insert the 2-skeleton giving some filtration values to the faces
st.insert([0,1,2], filtration=0.1)
st.insert([1,2,3], filtration=0.2)

# if you add a new simplex with a given filtration value, all its faces that
# were not in the complex before are added with the same filtration value
st.insert([2,3,4],filtration=0.7)
```

In [None]:
def get_simplex_tree_from_faces(faces):
    """Builds a simplex tree from a list of faces.

    Parameters
    ----------
    faces : numpy array
        faces of the shape

    Returns
    -------
    simplex_tree : gudhi.SimplexTree
        The simplex tree.
    """
    # TODO
    pass

Q.5 Write a function `compute_persistence_diagram` which takes as input the path of the shape and returns the persistence diagram of the shape in **degree 1.**

We will use a lower-star filtration based on the eccentricity of the vertices. The eccentricity of a vertex is the maximum distance to any other vertex in the shape.

You can follow the steps:
- use `read_data` to access the vertices and faces of the shape,
- use `get_simplex_tree_from_faces` to get the simplex tree,
- assign the filtration value (the eccentricity) to each vertex,
- use `make_filtration_non_decreasing` to ensure the filtration is non-decreasing, i.e. that every simplex appears no earlier than its faces,
- compute the persistence and keep the intervals in degree 1.

In [None]:
def compute_persistence_diagram(path):
    """
    Compute the persistence diagram of a shape.

    Parameters
    ----------
    path: path where the .off file of the shape is stored.
    """
    # TODO
    pass

Q.5.1 Compute the persistence diagram of a shape of your choice and display it. You can use `gudhi.plot_persistence_diagram`.

Q.6 Compute the persistence diagrams of all shapes. Write the function `get_persistence_diagrams`, which takes as input the path where the shapes are stored and returns the persistence diagrams of all shapes, as well as the corresponding labels and ids. Run the function on all the shapes.
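One way to organize the loop of Q.6, assuming the layout described in Q.1 (one subfolder per class, shapes stored as `.off` files). The diagram function is passed as an argument so the sketch does not depend on your solution to Q.5; the name `collect_diagrams` and the id format `<class>_<file stem>` are hypothetical choices.

```python
import os

def collect_diagrams(root, compute_diagram):
    """Sketch: walk root/<class_name>/*.off and compute one diagram per shape.

    `compute_diagram` is any function mapping a file path to a diagram
    (e.g. your compute_persistence_diagram from Q.5).
    """
    diagrams, labels, ids = [], [], []
    for class_name in sorted(os.listdir(root)):  # one subfolder per class
        class_dir = os.path.join(root, class_name)
        if not os.path.isdir(class_dir):
            continue
        for file_name in sorted(os.listdir(class_dir)):
            if not file_name.endswith(".off"):  # skip non-mesh files
                continue
            diagrams.append(compute_diagram(os.path.join(class_dir, file_name)))
            labels.append(class_name)
            ids.append(f"{class_name}_{os.path.splitext(file_name)[0]}")  # class + shape id
    return diagrams, labels, ids
```

In the notebook, this would be called as `collect_diagrams(dataset_path, compute_persistence_diagram)`.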

In [None]:
def get_persistence_diagrams(path):
    """
    Compute the persistence diagrams of all the shapes in the dataset.

    Parameters
    ----------
    path: str
        Path to the dataset folder (one subfolder per class).

    Returns
    -------
    diagrams: list
        List of persistence diagrams.
    labels: list
        List of labels.
    ids: list
        List of ids (can be of the form class + id of the shape).
    """
    # TODO
    pass

## II.2. Bottleneck distance and first classification

It is not convenient to use persistence diagrams directly as a featurization for machine learning purposes. As a first step, we will use the bottleneck distance to compare persistence diagrams, and then use a $k$-nearest neighbors classifier to predict the category of a shape.

Q.7 Pick a specific persistence diagram and use `DiagramSelector(use=True, point_type='finite')` to remove its points with infinite coordinates.

Q.8 Pick another persistence diagram, remove its infinite points, and compute the bottleneck distance between the two diagrams (`gudhi.representations.BottleneckDistance(epsilon=.001)`).

Q.9 Apply `MDS` to the matrix of pairwise bottleneck distances between all diagrams. Display the resulting 2D latent space with the labels.

Note. Don't forget to remove the infinite points of the diagrams, and to encode the labels with integer values between 0 and n_classes-1. You can use:
```python
import sklearn.preprocessing

le = sklearn.preprocessing.LabelEncoder().fit(labels)
integer_labels = le.transform(labels)
label_indices = [(l,le.classes_[l],np.argwhere(integer_labels==l).ravel()) for l in range(integer_labels.max()+1)]
```

Now, let's classify the shapes using the bottleneck distance between the persistence diagrams and a $k$-nearest neighbors classifier!

Q.10 Shuffle the data (use the diagrams and labels from Q.6, with infinite points removed and labels encoded as in the note of Q.9) and create a random 80/20 train/test split.

Q.11 Define, train and test a $k$-nearest neighbors classifier.
(Use the bottleneck distances between the persistence diagrams directly as input, and set `metric='precomputed'` in the `KNeighborsClassifier`.)

## II.3. Explore vectorization of persistence diagrams

Since it is not convenient to use persistence diagrams directly for machine learning purposes, in this section we will explore different ways to vectorize them. A vectorization is a map $\Phi: \mathcal{D} \rightarrow H$ sending persistence diagrams into a Hilbert space $H$ or, equivalently, a positive definite kernel function $k:\mathcal{D} \times \mathcal{D} \rightarrow \mathbb{R}$ such that $k(D,D')=\langle \Phi(D),\Phi(D')\rangle$.

For each featurization mentioned below, some parameters are suggested, but we recommend playing with them to understand their influence on the output and build some intuition.


Q.12 Pick a specific persistence diagram and use `DiagramSelector(use=True, point_type='finite')` to remove its points with infinite coordinates.

**Silhouette.** Silhouettes are a variation of [persistence landscapes](https://www.jmlr.org/papers/volume16/bubenik15a/bubenik15a.pdf). Persistence landscapes are obtained by rotating the persistence diagram by $-\pi/4$
(so that the diagonal becomes the $x$-axis), and then putting a tent function on each point. The $k$-th landscape is then defined as the $k$-th largest value among all these tent functions. It is finally turned into a vector by evaluating it on uniformly sampled points of the $x$-axis.
Silhouettes take a weighted average of these tent functions instead. Here, we weight each tent function by the distance of the corresponding point to the diagonal.
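To make the weighted average concrete, here is a from-scratch NumPy sketch (not Gudhi's implementation). Each point $(b, d)$ contributes the tent $\Lambda_j(t) = \max(0, \min(t - b, d - t))$, and the silhouette is $\sum_j w_j \Lambda_j(t) / \sum_j w_j$. The weight $d - b$ is proportional to the distance $(d - b)/\sqrt{2}$ to the diagonal, and the constant factor cancels in the normalized average. The function name and the toy diagram are illustrative.

```python
import numpy as np

def silhouette_from_scratch(diagram, grid, weight=lambda p: p[1] - p[0]):
    """Weighted silhouette of a finite diagram evaluated on `grid` (NumPy sketch)."""
    births = diagram[:, 0][:, None]   # shape (n_points, 1)
    deaths = diagram[:, 1][:, None]
    # Tent of each point evaluated on the grid: shape (n_points, len(grid)).
    tents = np.maximum(0.0, np.minimum(grid[None, :] - births, deaths - grid[None, :]))
    weights = np.array([weight(p) for p in diagram])
    # Weighted average of the tents.
    return (weights[:, None] * tents).sum(axis=0) / weights.sum()

diagram = np.array([[0.0, 2.0], [1.0, 3.0]])
grid = np.linspace(0.0, 3.0, 4)  # [0, 1, 2, 3]
print(silhouette_from_scratch(diagram, grid))  # expected values: 0, 0.5, 0.5, 0
```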

Q.13 Compute the silhouette of the chosen diagram. Use `gudhi.representations.Silhouette` with `resolution=1000`, and use the `weight` parameter to weight each tent by the distance of its point to the diagonal.

**Sliced Wasserstein Kernel** [was introduced in 2017](https://proceedings.mlr.press/v70/carriere17a/carriere17a.pdf). Sliced Wasserstein kernels use the sliced Wasserstein approximation of the Wasserstein distance to define a kernel for persistence diagrams (PDs).
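A from-scratch NumPy sketch of the approximation described in the paper, to show what the kernel computes (names are illustrative, and Gudhi's `SlicedWassersteinKernel` may normalize the bandwidth differently). Each diagram is augmented with the diagonal projections of the other, both are projected onto evenly spaced directions in $[-\pi/2, \pi/2)$, the 1D Wasserstein distance is obtained by matching sorted projections, and the average over directions approximates $SW$. The kernel is then $\exp(-SW/(2\sigma^2))$.

```python
import numpy as np

def diagonal_projection(diagram):
    """Project each point (b, d) orthogonally onto the diagonal: ((b+d)/2, (b+d)/2)."""
    mid = (diagram[:, 0] + diagram[:, 1]) / 2.0
    return np.stack([mid, mid], axis=1)

def sliced_wasserstein(d1, d2, num_directions=100):
    """Approximate sliced Wasserstein distance between two finite diagrams."""
    # Both augmented point sets have size len(d1) + len(d2).
    a = np.vstack([d1, diagonal_projection(d2)])
    b = np.vstack([d2, diagonal_projection(d1)])
    thetas = np.linspace(-np.pi / 2, np.pi / 2, num_directions, endpoint=False)
    directions = np.stack([np.cos(thetas), np.sin(thetas)], axis=0)  # shape (2, M)
    # On a line, W1 between equal-size point sets matches sorted projections.
    proj_a = np.sort(a @ directions, axis=0)
    proj_b = np.sort(b @ directions, axis=0)
    return np.abs(proj_a - proj_b).sum(axis=0).mean()  # average over directions

def sliced_wasserstein_kernel(d1, d2, sigma=1.0, num_directions=100):
    """Kernel exp(-SW / (2 sigma^2)), as written in the 2017 paper."""
    return np.exp(-sliced_wasserstein(d1, d2, num_directions) / (2 * sigma ** 2))

d1 = np.array([[0.0, 1.0]])
d2 = np.array([[0.0, 3.0], [1.0, 1.5]])
print(sliced_wasserstein(d1, d2), sliced_wasserstein_kernel(d1, d2))
```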

Q.14 Compute the sliced Wasserstein kernel of a pair of diagrams. You can use `gudhi.representations.SlicedWassersteinKernel`; as hyperparameters, you can start with `num_directions=100` and `bandwidth=1.0`.


Q.15 Apply dimensionality reduction techniques to the vectorized diagrams. Display the resulting 2D latent space with the labels.

Is there any method that looks better at separating the categories, at least by eye? Compare also with the 2D latent space obtained previously with the bottleneck distance.

Note. Don't forget to remove the infinite points of the diagrams. For the Sliced Wasserstein Kernel featurization, use `KernelPCA` (with `kernel='precomputed'`). To encode the target labels with values between 0 and n_classes-1, you can use:
```python
import sklearn.preprocessing

le = sklearn.preprocessing.LabelEncoder().fit(labels)
integer_labels = le.transform(labels)
label_indices = [(l,le.classes_[l],np.argwhere(integer_labels==l).ravel()) for l in range(integer_labels.max()+1)]
```

## II.4. Classification

Q.16 Let's now try classification on these featurizations! Shuffle the data and create a random 80/20 train/test split.

Q.17 Let's classify the PDs using Sliced Wasserstein Kernels, Silhouettes and simple classifiers!

Define a `Pipeline` with four estimators: one for selecting the finite persistence diagram points, one for scaling (or not) the persistence diagrams (with `DiagramScaler`), one for vectorizing persistence diagrams, and one for performing the final prediction. See the [documentation](https://scikit-learn.org/stable/modules/compose.html#combining-estimators).

Note. Gudhi provides estimator-like classes: they can be integrated seamlessly in a `Pipeline` of `Scikit-Learn` for model selection and cross-validation. A `Pipeline` is itself an estimator, and is initialized with a list of (name, estimator) pairs. When fitted, it sequentially applies the `fit_transform` method of each step to the output of the previous one, and finally calls `fit` on the last estimator.

For the Sliced Wasserstein kernel featurization, use `SVC` with `kernel='precomputed'` in the `Pipeline` as an estimator.

Q.17.1 Now, define a grid of parameters (for the different featurizations and estimators) that will be used in cross-validation.

Q.17.2 Define and train the model.

Q.17.3 Check the parameters that were chosen during model selection, and evaluate your model on the test set. Compare the results with those obtained using the bottleneck distance + $k$-NN.