Image Clustering: With rotation

Warning: Agglomerative clustering scales poorly with dataset size, so running this notebook on the full dataset requires a large amount of memory (about 360 GB on Kale).

[1]:
import warnings
from abc import ABC, abstractmethod

import matplotlib.pyplot as plt
import numpy as np
from hdbscan import HDBSCAN
from numba.errors import NumbaDeprecationWarning, NumbaWarning
from numpy.random import RandomState
from sklearn.cluster import KMeans, AgglomerativeClustering
from sklearn.metrics import adjusted_rand_score, adjusted_mutual_info_score

from dpemu import runner
from dpemu.dataset_utils import load_mnist_unsplit
from dpemu.filters.image import Rotation
from dpemu.ml_utils import reduce_dimensions
from dpemu.nodes import Array
from dpemu.nodes.series import Series
from dpemu.plotting_utils import visualize_best_model_params, visualize_scores, visualize_classes, \
    print_results_by_model

# Ignore every warning of the two Numba categories imported above so they do
# not clutter the cell output (presumably they are triggered indirectly by the
# clustering / dimensionality-reduction dependencies -- TODO confirm).
warnings.simplefilter("ignore", category=NumbaDeprecationWarning)
warnings.simplefilter("ignore", category=NumbaWarning)
/wrk/users/thalvari/dpEmu/venv/lib/python3.7/site-packages/sklearn/externals/six.py:31: DeprecationWarning: The module is deprecated in version 0.21 and will be removed in version 0.23 since we've dropped support for Python 2.7. Please rely on the official version of six (https://pypi.org/project/six/).
  "(https://pypi.org/project/six/).", DeprecationWarning)
/wrk/users/thalvari/dpEmu/venv/lib/python3.7/site-packages/sklearn/externals/joblib/__init__.py:15: DeprecationWarning: sklearn.externals.joblib is deprecated in 0.21 and will be removed in 0.23. Please import this functionality directly from joblib, which can be installed with: pip install joblib. If this warning is raised when loading pickled models, you may need to re-serialize those models with scikit-learn 0.21+.
  warnings.warn(msg, category=DeprecationWarning)
[2]:
def get_data(n_data=70000):
    """Load the dataset used in this case study via ``load_mnist_unsplit``.

    Args:
        n_data: Sample count forwarded to ``load_mnist_unsplit``. The default
            (70000) is the value this notebook originally hard-coded. Passing
            a smaller number is the easiest way to reduce the very high memory
            usage of the agglomerative clustering step.

    Returns:
        Whatever ``load_mnist_unsplit`` returns; ``main`` unpacks it as
        ``(data, labels, label_names, dataset_name)``.
    """
    return load_mnist_unsplit(n_data)
[3]:
def get_err_root_node():
    """Build the error-generation tree: a ``Series`` wrapping an ``Array`` node.

    The ``Array`` node is created with ``reshape=(28, 28)`` and carries a single
    ``Rotation`` filter whose two arguments are the parameter names
    ``"min_angle"`` and ``"max_angle"`` (the keys produced by
    ``get_err_params_list``).

    Returns:
        The ``Series`` root node.
    """
    image_node = Array(reshape=(28, 28))
    rotation_filter = Rotation("min_angle", "max_angle")
    image_node.addfilter(rotation_filter)
    return Series(image_node)
[4]:
def get_err_params_list(max_angle=84, n_steps=8):
    """Create the list of rotation error parameters to sweep over.

    The angles are ``n_steps`` evenly spaced values from 0 to ``max_angle``
    (both inclusive). Each one yields a symmetric range
    ``{"min_angle": -a, "max_angle": a}``.

    Args:
        max_angle: Largest rotation bound in the sweep. Defaults to 84, the
            value originally hard-coded here.
        n_steps: Number of evenly spaced steps. Defaults to 8, which together
            with the default ``max_angle`` gives 0, 12, ..., 84.

    Returns:
        A list of ``n_steps`` dicts with the keys ``"min_angle"`` and
        ``"max_angle"``.
    """
    angle_steps = np.linspace(0, max_angle, num=n_steps)
    return [{"min_angle": -angle, "max_angle": angle} for angle in angle_steps]
[5]:
class Preprocessor:
    """Preprocessing step that reduces the dimensionality of the data."""

    def __init__(self):
        # Fixed seed (42) for the random state handed to reduce_dimensions.
        self.random_state = RandomState(42)

    def run(self, _, data, params):
        """Reduce ``data`` with ``reduce_dimensions`` and return the result.

        The first positional argument and ``params`` are not used.

        Returns:
            A 3-tuple ``(None, reduced, extras)`` where ``reduced`` is the
            output of ``reduce_dimensions`` and ``extras`` is a dict holding
            the same object under the key ``"reduced_data"``.
        """
        embedding = reduce_dimensions(data, self.random_state)
        extras = {"reduced_data": embedding}
        return None, embedding, extras
[6]:
class AbstractModel(ABC):
    """Common base for the clustering models compared in this notebook.

    Subclasses provide ``get_fitted_model``; ``run`` scores the fitted model's
    cluster assignment against the true labels.
    """

    def __init__(self):
        # Fixed seed (42); subclasses may pass this on to their estimator.
        self.random_state = RandomState(42)

    @abstractmethod
    def get_fitted_model(self, data, params):
        """Fit a clustering model on ``data`` and return it.

        The returned object must expose its cluster assignment as ``labels_``,
        because ``run`` reads that attribute.
        """
        pass

    def run(self, _, data, params):
        """Fit the model and score its clustering against ``params["labels"]``.

        The first positional argument is ignored.

        Returns:
            A dict with the adjusted mutual information (``"AMI"``, arithmetic
            averaging) and the adjusted Rand index (``"ARI"``), each rounded
            to three decimals.
        """
        true_labels = params["labels"]
        predicted_labels = self.get_fitted_model(data, params).labels_
        ami = adjusted_mutual_info_score(true_labels, predicted_labels, average_method="arithmetic")
        ari = adjusted_rand_score(true_labels, predicted_labels)
        return {"AMI": round(ami, 3), "ARI": round(ari, 3)}


class KMeansModel(AbstractModel):
    """K-means clustering with one cluster per distinct true label."""

    def __init__(self):
        super().__init__()

    def get_fitted_model(self, data, params):
        """Fit ``KMeans`` on ``data`` and return the result of ``fit``.

        The number of clusters is the count of unique values in
        ``params["labels"]``; the estimator is seeded with this instance's
        ``random_state``.
        """
        n_clusters = len(np.unique(params["labels"]))
        kmeans = KMeans(n_clusters=n_clusters, random_state=self.random_state)
        return kmeans.fit(data)


class AgglomerativeModel(AbstractModel):
    """Agglomerative clustering with one cluster per distinct true label."""

    def __init__(self):
        super().__init__()

    def get_fitted_model(self, data, params):
        """Fit ``AgglomerativeClustering`` on ``data`` and return the result of ``fit``.

        The number of clusters is the count of unique values in
        ``params["labels"]``. No random state is passed to the estimator.
        """
        n_clusters = len(np.unique(params["labels"]))
        clusterer = AgglomerativeClustering(n_clusters=n_clusters)
        return clusterer.fit(data)


class HDBSCANModel(AbstractModel):
    """HDBSCAN clustering configured from the swept hyperparameters."""

    def __init__(self):
        super().__init__()

    def get_fitted_model(self, data, params):
        """Fit ``HDBSCAN`` on ``data`` and return the result of ``fit``.

        Reads ``params["min_samples"]`` and ``params["min_cluster_size"]`` and
        always passes ``core_dist_n_jobs=1``.
        """
        min_samples = params["min_samples"]
        min_cluster_size = params["min_cluster_size"]
        clusterer = HDBSCAN(
            min_samples=min_samples,
            min_cluster_size=min_cluster_size,
            core_dist_n_jobs=1,
        )
        return clusterer.fit(data)
[7]:
def get_model_params_dict_list(data, labels):
    """Describe which models to run and with which hyperparameters.

    KMeans and Agglomerative each get a single parameter set containing only
    the labels. HDBSCAN gets a grid: every ``min_cluster_size`` (the number of
    rows in ``data`` divided by each divisor, rounded) combined with every
    ``min_samples`` value.

    Args:
        data: Object with a ``shape`` attribute; only ``shape[0]`` is read.
        labels: True labels, attached to every parameter set under ``"labels"``.

    Returns:
        A list of dicts, one per model, with the keys ``"model"`` (the model
        class) and ``"params_list"`` (a list of parameter dicts).
    """
    n_samples = data.shape[0]

    # Outer loop over cluster sizes, inner loop over min_samples: this keeps
    # the grid in the same order as a nested comprehension would.
    hdbscan_grid = []
    for div in [12, 15, 20, 30, 45, 65, 90]:
        cluster_size = round(n_samples / div)
        for n_min_samples in [1, 5, 10]:
            hdbscan_grid.append({
                "min_cluster_size": cluster_size,
                "min_samples": n_min_samples,
                "labels": labels,
            })

    kmeans_entry = {"model": KMeansModel, "params_list": [{"labels": labels}]}
    agglomerative_entry = {"model": AgglomerativeModel, "params_list": [{"labels": labels}]}
    hdbscan_entry = {"model": HDBSCANModel, "params_list": hdbscan_grid}
    return [kmeans_entry, agglomerative_entry, hdbscan_entry]
[8]:
def visualize(df, label_names, dataset_name, data):
    """Draw the three result figures for the rotation experiment and show them.

    Calls, in order, ``visualize_scores`` (AMI/ARI against ``max_angle``),
    ``visualize_best_model_params`` (for the model named ``"HDBSCAN"``) and
    ``visualize_classes`` (using the ``"reduced_data"`` and ``"labels"``
    columns), then ``plt.show()``.

    Args:
        df: Results frame passed to every plotting helper.
        label_names: Passed through to ``visualize_classes``.
        dataset_name: Used in the figure titles.
        data: Object with a ``shape`` attribute; ``shape[0]`` appears in the
            title of the class plot.
    """
    err_param = "max_angle"
    metrics = ["AMI", "ARI"]
    higher_is_better = [True, True]

    # Each helper receives its own copy of the lists, as if written inline.
    visualize_scores(
        df,
        score_names=list(metrics),
        is_higher_score_better=list(higher_is_better),
        err_param_name=err_param,
        title=f"{dataset_name} clustering scores with rotation",
    )

    visualize_best_model_params(
        df,
        model_name="HDBSCAN",
        model_params=["min_cluster_size", "min_samples"],
        score_names=list(metrics),
        is_higher_score_better=list(higher_is_better),
        err_param_name=err_param,
        title=f"Best parameters for {dataset_name} clustering",
    )

    classes_title = f"{dataset_name} (n={data.shape[0]}) true classes with rotation"
    visualize_classes(
        df,
        label_names,
        err_param_name=err_param,
        reduced_data_column="reduced_data",
        labels_column="labels",
        cmap="tab10",
        title=classes_title,
    )

    plt.show()
[9]:
def main():
    """Run the whole case study: load data, run the experiment, report results.

    The data is handed to ``runner.run`` as ``test_data`` only (``train_data``
    is ``None``), together with the error tree, the error parameters and the
    model parameter grid built by the helper functions. The returned frame is
    then printed per model and visualised.
    """
    data, labels, label_names, dataset_name = get_data()

    # Build the experiment configuration in the same order in which it is
    # passed to the runner below.
    err_root_node = get_err_root_node()
    err_params_list = get_err_params_list()
    model_params_dict_list = get_model_params_dict_list(data, labels)

    results = runner.run(
        train_data=None,
        test_data=data,
        preproc=Preprocessor,
        preproc_params=None,
        err_root_node=err_root_node,
        err_params_list=err_params_list,
        model_params_dict_list=model_params_dict_list,
    )

    # Presumably the columns left out of the printed tables -- they do not
    # appear in the recorded output; verify against print_results_by_model.
    hidden_columns = ["min_angle", "reduced_data", "labels"]
    print_results_by_model(results, hidden_columns)
    visualize(results, label_names, dataset_name, data)
[10]:
# Run the full experiment. This is long-running: the progress bar of the
# recorded run below shows 8 iterations taking 11:45 in total.
main()
100%|██████████| 8/8 [11:45<00:00, 88.18s/it]
Agglomerative #1
AMI ARI max_angle time_err time_pre time_mod
0 0.888 0.819 0.0 5.404 445.732 185.305
1 0.871 0.804 12.0 5.886 465.016 183.202
2 0.805 0.707 24.0 6.258 448.705 184.230
3 0.701 0.542 36.0 5.708 454.822 186.735
4 0.689 0.530 48.0 6.299 461.554 185.352
5 0.607 0.419 60.0 16.806 448.424 185.581
6 0.580 0.398 72.0 6.675 452.017 188.154
7 0.522 0.342 84.0 16.912 447.885 186.312
HDBSCAN #1
AMI ARI max_angle min_cluster_size min_samples time_err time_pre time_mod
0 0.902 0.853 0.0 5833.0 1.0 5.404 445.732 2.263
1 0.901 0.852 0.0 5833.0 5.0 5.404 445.732 2.498
2 0.903 0.853 0.0 5833.0 10.0 5.404 445.732 2.572
3 0.902 0.853 0.0 4667.0 1.0 5.404 445.732 2.222
4 0.901 0.852 0.0 4667.0 5.0 5.404 445.732 2.456
5 0.903 0.853 0.0 4667.0 10.0 5.404 445.732 2.530
6 0.902 0.853 0.0 3500.0 1.0 5.404 445.732 2.171
7 0.901 0.852 0.0 3500.0 5.0 5.404 445.732 2.416
8 0.903 0.853 0.0 3500.0 10.0 5.404 445.732 2.516
9 0.902 0.853 0.0 2333.0 1.0 5.404 445.732 2.170
10 0.901 0.852 0.0 2333.0 5.0 5.404 445.732 2.448
11 0.903 0.853 0.0 2333.0 10.0 5.404 445.732 2.535
12 0.902 0.853 0.0 1556.0 1.0 5.404 445.732 2.201
13 0.901 0.852 0.0 1556.0 5.0 5.404 445.732 2.449
14 0.903 0.853 0.0 1556.0 10.0 5.404 445.732 2.591
15 0.902 0.853 0.0 1077.0 1.0 5.404 445.732 2.204
16 0.901 0.852 0.0 1077.0 5.0 5.404 445.732 2.453
17 0.903 0.853 0.0 1077.0 10.0 5.404 445.732 2.555
18 0.902 0.853 0.0 778.0 1.0 5.404 445.732 2.230
19 0.901 0.852 0.0 778.0 5.0 5.404 445.732 2.473
20 0.903 0.853 0.0 778.0 10.0 5.404 445.732 2.615
21 0.883 0.837 12.0 5833.0 1.0 5.886 465.016 2.122
22 0.882 0.836 12.0 5833.0 5.0 5.886 465.016 2.423
23 0.882 0.836 12.0 5833.0 10.0 5.886 465.016 2.611
24 0.883 0.837 12.0 4667.0 1.0 5.886 465.016 2.132
25 0.882 0.836 12.0 4667.0 5.0 5.886 465.016 2.400
26 0.882 0.836 12.0 4667.0 10.0 5.886 465.016 2.570
27 0.883 0.837 12.0 3500.0 1.0 5.886 465.016 2.133
28 0.882 0.836 12.0 3500.0 5.0 5.886 465.016 2.412
29 0.882 0.836 12.0 3500.0 10.0 5.886 465.016 2.596
30 0.883 0.837 12.0 2333.0 1.0 5.886 465.016 2.151
31 0.882 0.836 12.0 2333.0 5.0 5.886 465.016 2.423
32 0.882 0.836 12.0 2333.0 10.0 5.886 465.016 2.555
33 0.883 0.837 12.0 1556.0 1.0 5.886 465.016 2.166
34 0.882 0.836 12.0 1556.0 5.0 5.886 465.016 2.457
35 0.882 0.836 12.0 1556.0 10.0 5.886 465.016 2.475
36 0.883 0.837 12.0 1077.0 1.0 5.886 465.016 2.070
37 0.882 0.836 12.0 1077.0 5.0 5.886 465.016 2.362
38 0.882 0.836 12.0 1077.0 10.0 5.886 465.016 2.484
39 0.883 0.837 12.0 778.0 1.0 5.886 465.016 2.096
40 0.882 0.836 12.0 778.0 5.0 5.886 465.016 2.468
41 0.882 0.836 12.0 778.0 10.0 5.886 465.016 2.350
42 0.806 0.642 24.0 5833.0 1.0 6.258 448.705 2.170
43 0.865 0.819 24.0 5833.0 5.0 6.258 448.705 2.429
44 0.806 0.642 24.0 5833.0 10.0 6.258 448.705 2.599
45 0.836 0.747 24.0 4667.0 1.0 6.258 448.705 2.149
46 0.865 0.819 24.0 4667.0 5.0 6.258 448.705 2.403
47 0.862 0.816 24.0 4667.0 10.0 6.258 448.705 2.559
48 0.836 0.747 24.0 3500.0 1.0 6.258 448.705 2.152
49 0.865 0.819 24.0 3500.0 5.0 6.258 448.705 2.415
50 0.862 0.816 24.0 3500.0 10.0 6.258 448.705 2.576
51 0.836 0.747 24.0 2333.0 1.0 6.258 448.705 2.147
52 0.865 0.819 24.0 2333.0 5.0 6.258 448.705 2.420
53 0.862 0.816 24.0 2333.0 10.0 6.258 448.705 2.571
54 0.836 0.747 24.0 1556.0 1.0 6.258 448.705 2.153
55 0.865 0.819 24.0 1556.0 5.0 6.258 448.705 2.437
56 0.862 0.816 24.0 1556.0 10.0 6.258 448.705 2.589
57 0.836 0.747 24.0 1077.0 1.0 6.258 448.705 2.161
58 0.865 0.819 24.0 1077.0 5.0 6.258 448.705 2.460
59 0.862 0.816 24.0 1077.0 10.0 6.258 448.705 2.608
60 0.836 0.747 24.0 778.0 1.0 6.258 448.705 2.177
61 0.865 0.819 24.0 778.0 5.0 6.258 448.705 2.453
62 0.862 0.816 24.0 778.0 10.0 6.258 448.705 2.617
63 0.738 0.497 36.0 5833.0 1.0 5.708 454.822 2.295
64 0.738 0.497 36.0 5833.0 5.0 5.708 454.822 2.409
65 0.739 0.498 36.0 5833.0 10.0 5.708 454.822 2.567
66 0.738 0.497 36.0 4667.0 1.0 5.708 454.822 2.174
67 0.818 0.736 36.0 4667.0 5.0 5.708 454.822 2.459
68 0.820 0.736 36.0 4667.0 10.0 5.708 454.822 2.735
69 0.738 0.497 36.0 3500.0 1.0 5.708 454.822 2.196
70 0.818 0.736 36.0 3500.0 5.0 5.708 454.822 2.468
71 0.820 0.736 36.0 3500.0 10.0 5.708 454.822 2.626
72 0.819 0.736 36.0 2333.0 1.0 5.708 454.822 2.199
73 0.835 0.788 36.0 2333.0 5.0 5.708 454.822 2.614
74 0.837 0.789 36.0 2333.0 10.0 5.708 454.822 2.623
75 0.738 0.497 36.0 1556.0 1.0 5.708 454.822 2.195
76 0.835 0.788 36.0 1556.0 5.0 5.708 454.822 2.496
77 0.820 0.736 36.0 1556.0 10.0 5.708 454.822 2.597
78 0.738 0.497 36.0 1077.0 1.0 5.708 454.822 2.295
79 0.835 0.788 36.0 1077.0 5.0 5.708 454.822 2.558
80 0.837 0.789 36.0 1077.0 10.0 5.708 454.822 2.604
81 0.738 0.497 36.0 778.0 1.0 5.708 454.822 2.198
82 0.835 0.788 36.0 778.0 5.0 5.708 454.822 2.481
83 0.820 0.736 36.0 778.0 10.0 5.708 454.822 2.514
84 0.731 0.492 48.0 5833.0 1.0 6.299 461.554 2.113
85 0.730 0.492 48.0 5833.0 5.0 6.299 461.554 2.343
86 0.731 0.492 48.0 5833.0 10.0 6.299 461.554 2.464
87 0.731 0.492 48.0 4667.0 1.0 6.299 461.554 2.112
88 0.730 0.492 48.0 4667.0 5.0 6.299 461.554 2.343
89 0.731 0.492 48.0 4667.0 10.0 6.299 461.554 2.496
90 0.731 0.492 48.0 3500.0 1.0 6.299 461.554 2.143
91 0.805 0.726 48.0 3500.0 5.0 6.299 461.554 2.385
92 0.731 0.492 48.0 3500.0 10.0 6.299 461.554 2.468
93 0.812 0.761 48.0 2333.0 1.0 6.299 461.554 2.150
94 0.805 0.726 48.0 2333.0 5.0 6.299 461.554 2.421
95 0.731 0.492 48.0 2333.0 10.0 6.299 461.554 2.498
96 0.812 0.761 48.0 1556.0 1.0 6.299 461.554 2.181
97 0.802 0.739 48.0 1556.0 5.0 6.299 461.554 2.432
98 0.731 0.492 48.0 1556.0 10.0 6.299 461.554 2.454
99 0.731 0.492 48.0 1077.0 1.0 6.299 461.554 2.055
100 0.730 0.492 48.0 1077.0 5.0 6.299 461.554 2.327
101 0.731 0.492 48.0 1077.0 10.0 6.299 461.554 2.436
102 0.731 0.492 48.0 778.0 1.0 6.299 461.554 2.093
103 0.730 0.492 48.0 778.0 5.0 6.299 461.554 2.391
104 0.731 0.492 48.0 778.0 10.0 6.299 461.554 2.548
105 0.721 0.487 60.0 5833.0 1.0 16.806 448.424 2.179
106 0.722 0.487 60.0 5833.0 5.0 16.806 448.424 2.342
107 0.720 0.486 60.0 5833.0 10.0 16.806 448.424 2.457
108 0.721 0.487 60.0 4667.0 1.0 16.806 448.424 2.133
109 0.722 0.487 60.0 4667.0 5.0 16.806 448.424 2.374
110 0.720 0.486 60.0 4667.0 10.0 16.806 448.424 2.480
111 0.721 0.487 60.0 3500.0 1.0 16.806 448.424 2.138
112 0.722 0.487 60.0 3500.0 5.0 16.806 448.424 2.361
113 0.720 0.486 60.0 3500.0 10.0 16.806 448.424 2.494
114 0.721 0.487 60.0 2333.0 1.0 16.806 448.424 2.151
115 0.722 0.487 60.0 2333.0 5.0 16.806 448.424 2.388
116 0.720 0.486 60.0 2333.0 10.0 16.806 448.424 2.544
117 0.721 0.487 60.0 1556.0 1.0 16.806 448.424 2.147
118 0.722 0.487 60.0 1556.0 5.0 16.806 448.424 2.396
119 0.720 0.486 60.0 1556.0 10.0 16.806 448.424 2.487
120 0.721 0.487 60.0 1077.0 1.0 16.806 448.424 2.228
121 0.722 0.487 60.0 1077.0 5.0 16.806 448.424 2.435
122 0.720 0.486 60.0 1077.0 10.0 16.806 448.424 2.519
123 0.721 0.487 60.0 778.0 1.0 16.806 448.424 2.202
124 0.722 0.487 60.0 778.0 5.0 16.806 448.424 2.333
125 0.720 0.486 60.0 778.0 10.0 16.806 448.424 2.435
126 0.710 0.482 72.0 5833.0 1.0 6.675 452.017 2.060
127 0.709 0.481 72.0 5833.0 5.0 6.675 452.017 2.336
128 0.708 0.481 72.0 5833.0 10.0 6.675 452.017 2.434
129 0.710 0.482 72.0 4667.0 1.0 6.675 452.017 2.066
130 0.709 0.481 72.0 4667.0 5.0 6.675 452.017 2.339
131 0.708 0.481 72.0 4667.0 10.0 6.675 452.017 2.443
132 0.699 0.463 72.0 3500.0 1.0 6.675 452.017 2.076
133 0.698 0.462 72.0 3500.0 5.0 6.675 452.017 2.344
134 0.698 0.462 72.0 3500.0 10.0 6.675 452.017 2.471
135 0.700 0.463 72.0 2333.0 1.0 6.675 452.017 2.103
136 0.700 0.463 72.0 2333.0 5.0 6.675 452.017 2.363
137 0.700 0.463 72.0 2333.0 10.0 6.675 452.017 2.460
138 0.700 0.463 72.0 1556.0 1.0 6.675 452.017 2.113
139 0.700 0.463 72.0 1556.0 5.0 6.675 452.017 2.382
140 0.700 0.463 72.0 1556.0 10.0 6.675 452.017 2.488
141 0.700 0.463 72.0 1077.0 1.0 6.675 452.017 2.098
142 0.700 0.463 72.0 1077.0 5.0 6.675 452.017 2.382
143 0.700 0.463 72.0 1077.0 10.0 6.675 452.017 2.477
144 0.679 0.431 72.0 778.0 1.0 6.675 452.017 2.098
145 0.700 0.463 72.0 778.0 5.0 6.675 452.017 2.358
146 0.700 0.463 72.0 778.0 10.0 6.675 452.017 2.470
147 0.405 0.111 84.0 5833.0 1.0 16.912 447.885 2.167
148 0.405 0.111 84.0 5833.0 5.0 16.912 447.885 2.388
149 0.405 0.111 84.0 5833.0 10.0 16.912 447.885 2.536
150 0.405 0.111 84.0 4667.0 1.0 16.912 447.885 2.131
151 0.405 0.111 84.0 4667.0 5.0 16.912 447.885 2.422
152 0.405 0.111 84.0 4667.0 10.0 16.912 447.885 2.556
153 0.405 0.111 84.0 3500.0 1.0 16.912 447.885 2.146
154 0.405 0.111 84.0 3500.0 5.0 16.912 447.885 2.409
155 0.405 0.111 84.0 3500.0 10.0 16.912 447.885 2.557
156 0.405 0.111 84.0 2333.0 1.0 16.912 447.885 2.147
157 0.405 0.111 84.0 2333.0 5.0 16.912 447.885 2.461
158 0.405 0.111 84.0 2333.0 10.0 16.912 447.885 2.579
159 0.405 0.111 84.0 1556.0 1.0 16.912 447.885 2.178
160 0.405 0.111 84.0 1556.0 5.0 16.912 447.885 2.457
161 0.405 0.111 84.0 1556.0 10.0 16.912 447.885 2.562
162 0.405 0.111 84.0 1077.0 1.0 16.912 447.885 2.110
163 0.405 0.111 84.0 1077.0 5.0 16.912 447.885 2.349
164 0.405 0.111 84.0 1077.0 10.0 16.912 447.885 2.478
165 0.405 0.111 84.0 778.0 1.0 16.912 447.885 2.105
166 0.405 0.111 84.0 778.0 5.0 16.912 447.885 2.381
167 0.405 0.111 84.0 778.0 10.0 16.912 447.885 2.368
KMeans #1
AMI ARI max_angle time_err time_pre time_mod
0 0.886 0.817 0.0 5.404 445.732 1.163
1 0.863 0.794 12.0 5.886 465.016 1.379
2 0.801 0.702 24.0 6.258 448.705 1.382
3 0.765 0.654 36.0 5.708 454.822 1.749
4 0.656 0.492 48.0 6.299 461.554 2.179
5 0.640 0.474 60.0 16.806 448.424 2.301
6 0.567 0.386 72.0 6.675 452.017 2.628
7 0.536 0.356 84.0 16.912 447.885 2.123
../_images/case_studies_Image_Clustering_With_Rotation_20_7.png
../_images/case_studies_Image_Clustering_With_Rotation_20_8.png
../_images/case_studies_Image_Clustering_With_Rotation_20_9.png

The notebook for this case study can be found here.