Image Clustering: Added Noise

Warning: Agglomerative clustering scales poorly with dataset size; on the full 70,000-image dataset it required roughly 360 GB of memory on the Kale cluster.
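As a rough back-of-the-envelope sketch (not part of the notebook): without a connectivity constraint, scikit-learn's AgglomerativeClustering works from pairwise distances, so its memory use grows roughly quadratically in the number of samples. For 70,000 images a single condensed distance matrix alone already needs about 20 GB, and the actual peak can be several times that:

    # Illustrative lower bound only; the real peak depends on the linkage implementation.
    n = 70_000
    bytes_per_float = 8
    condensed = n * (n - 1) // 2 * bytes_per_float  # condensed pairwise-distance matrix
    print(f"{condensed / 1e9:.1f} GB")              # ~19.6 GB for one float64 copy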

[1]:
import warnings
from abc import ABC, abstractmethod

import matplotlib.pyplot as plt
import numpy as np
from hdbscan import HDBSCAN
from numba.errors import NumbaDeprecationWarning, NumbaWarning
from numpy.random import RandomState
from sklearn.cluster import KMeans, AgglomerativeClustering
from sklearn.metrics import adjusted_rand_score, adjusted_mutual_info_score

from dpemu import runner
from dpemu.dataset_utils import load_mnist_unsplit
from dpemu.filters.common import GaussianNoise, Clip
from dpemu.ml_utils import reduce_dimensions
from dpemu.nodes import Array
from dpemu.nodes.series import Series
from dpemu.plotting_utils import visualize_best_model_params, visualize_scores, visualize_classes, \
    print_results_by_model

warnings.simplefilter("ignore", category=NumbaDeprecationWarning)
warnings.simplefilter("ignore", category=NumbaWarning)
/wrk/users/thalvari/dpEmu/venv/lib/python3.7/site-packages/sklearn/externals/six.py:31: DeprecationWarning: The module is deprecated in version 0.21 and will be removed in version 0.23 since we've dropped support for Python 2.7. Please rely on the official version of six (https://pypi.org/project/six/).
  "(https://pypi.org/project/six/).", DeprecationWarning)
/wrk/users/thalvari/dpEmu/venv/lib/python3.7/site-packages/sklearn/externals/joblib/__init__.py:15: DeprecationWarning: sklearn.externals.joblib is deprecated in 0.21 and will be removed in 0.23. Please import this functionality directly from joblib, which can be installed with: pip install joblib. If this warning is raised when loading pickled models, you may need to re-serialize those models with scikit-learn 0.21+.
  warnings.warn(msg, category=DeprecationWarning)
[2]:
def get_data():
    return load_mnist_unsplit(70000)
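For orientation, load_mnist_unsplit returns the images together with the labels, label names and dataset name that main() unpacks below. A quick sanity check might look like the following sketch (the shapes assume the usual flattened 70,000 × 784 MNIST layout implied by reshape=(28, 28) in the next cell):

    data, labels, label_names, dataset_name = get_data()
    print(data.shape, labels.shape)       # expected: (70000, 784) (70000,)
    print(dataset_name, len(label_names))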
[3]:
def get_err_root_node():
    # Each element of the series is one flattened image, interpreted as a 28x28 array.
    err_img_node = Array(reshape=(28, 28))
    err_root_node = Series(err_img_node)
    # Add Gaussian noise to every pixel, then clip the result back to the original value range.
    err_img_node.addfilter(GaussianNoise("mean", "std"))
    err_img_node.addfilter(Clip("min_val", "max_val"))
    return err_root_node
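The runner applies this error tree for every parameter combination, but the same tree can also be applied to data directly. Assuming dpEmu's generate_error interface, a minimal stand-alone sketch would be:

    # Hypothetical stand-alone usage of the error tree; std=50 is an arbitrary example value.
    data, labels, _, _ = get_data()
    noisy_data = get_err_root_node().generate_error(
        data, {"mean": 0, "std": 50, "min_val": np.amin(data), "max_val": np.amax(data)}
    )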
[4]:
def get_err_params_list(data):
    min_val = np.amin(data)
    max_val = np.amax(data)
    std_steps = np.linspace(0, max_val, num=8)
    err_params_list = [{"mean": 0, "std": std, "min_val": min_val, "max_val": max_val} for std in std_steps]
    return err_params_list
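Since the pixel values here range from 0 to 255, the eight std steps are evenly spaced over that range; these are exactly the values that appear in the std column of the result tables below:

    # std steps for max_val = 255
    print(np.linspace(0, 255, num=8).round(2))
    # -> 0, 36.43, 72.86, 109.29, 145.71, 182.14, 218.57, 255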
[5]:
class Preprocessor:
    def __init__(self):
        self.random_state = RandomState(42)

    def run(self, _, data, params):
        # Reduce the (possibly noisy) images to a lower-dimensional representation before
        # clustering; the reduced data is also returned so it can be visualized later.
        reduced_data = reduce_dimensions(data, self.random_state)
        return None, reduced_data, {"reduced_data": reduced_data}
[6]:
class AbstractModel(ABC):

    def __init__(self):
        self.random_state = RandomState(42)

    @abstractmethod
    def get_fitted_model(self, data, params):
        pass

    def run(self, _, data, params):
        labels = params["labels"]
        fitted_model = self.get_fitted_model(data, params)
        return {
            "AMI": round(adjusted_mutual_info_score(labels, fitted_model.labels_, average_method="arithmetic"), 3),
            "ARI": round(adjusted_rand_score(labels, fitted_model.labels_), 3),
        }


class KMeansModel(AbstractModel):

    def __init__(self):
        super().__init__()

    def get_fitted_model(self, data, params):
        labels = params["labels"]
        n_classes = len(np.unique(labels))
        return KMeans(n_clusters=n_classes, random_state=self.random_state).fit(data)


class AgglomerativeModel(AbstractModel):

    def __init__(self):
        super().__init__()

    def get_fitted_model(self, data, params):
        labels = params["labels"]
        n_classes = len(np.unique(labels))
        return AgglomerativeClustering(n_clusters=n_classes).fit(data)


class HDBSCANModel(AbstractModel):

    def __init__(self):
        super().__init__()

    def get_fitted_model(self, data, params):
        return HDBSCAN(
            min_samples=params["min_samples"],
            min_cluster_size=params["min_cluster_size"],
            core_dist_n_jobs=1
        ).fit(data)
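Each model can also be exercised on its own; run only needs the preprocessed data and a params dict containing the true labels. A usage sketch built from the pieces defined above:

    # Stand-alone usage of one model, outside the dpEmu runner.
    data, labels, _, _ = get_data()
    _, reduced_data, _ = Preprocessor().run(None, data, {})
    print(KMeansModel().run(None, reduced_data, {"labels": labels}))  # {'AMI': ..., 'ARI': ...}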
[7]:
def get_model_params_dict_list(data, labels):
    n_data = data.shape[0]
    divs = [12, 15, 20, 30, 45, 65, 90]
    min_cluster_size_steps = [round(n_data / div) for div in divs]
    min_samples_steps = [1, 5, 10]
    return [
        {"model": KMeansModel, "params_list": [{"labels": labels}]},
        {"model": AgglomerativeModel, "params_list": [{"labels": labels}]},
        {"model": HDBSCANModel, "params_list": [{
            "min_cluster_size": min_cluster_size,
            "min_samples": min_samples,
            "labels": labels
        } for min_cluster_size in min_cluster_size_steps for min_samples in min_samples_steps]},
    ]
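With n_data = 70000, the divisors translate into the min_cluster_size values seen in the HDBSCAN results below:

    # min_cluster_size steps for n_data = 70000
    print([round(70000 / div) for div in [12, 15, 20, 30, 45, 65, 90]])
    # -> [5833, 4667, 3500, 2333, 1556, 1077, 778]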
[8]:
def visualize(df, label_names, dataset_name, data):
    visualize_scores(
        df,
        score_names=["AMI", "ARI"],
        is_higher_score_better=[True, True],
        err_param_name="std",
        title=f"{dataset_name} clustering scores with added noise",
    )
    visualize_best_model_params(
        df,
        model_name="HDBSCAN",
        model_params=["min_cluster_size", "min_samples"],
        score_names=["AMI", "ARI"],
        is_higher_score_better=[True, True],
        err_param_name="std",
        title=f"Best parameters for {dataset_name} clustering"
    )
    visualize_classes(
        df,
        label_names,
        err_param_name="std",
        reduced_data_column="reduced_data",
        labels_column="labels",
        cmap="tab10",
        title=f"{dataset_name} (n={data.shape[0]}) true classes with added noise"
    )
    plt.show()
[9]:
def main():
    data, labels, label_names, dataset_name = get_data()

    df = runner.run(
        train_data=None,
        test_data=data,
        preproc=Preprocessor,
        preproc_params=None,
        err_root_node=get_err_root_node(),
        err_params_list=get_err_params_list(data),
        model_params_dict_list=get_model_params_dict_list(data, labels),
    )

    print_results_by_model(df, ["mean", "min_val", "max_val", "reduced_data", "labels"])
    visualize(df, label_names, dataset_name, data)
[10]:
main()
100%|██████████| 8/8 [16:28<00:00, 123.56s/it]
Agglomerative #1
AMI ARI std time_err time_pre time_mod
0 0.886 0.819 0.000000 6.254 398.776 180.600
1 0.883 0.816 36.428571 6.359 413.132 203.826
2 0.872 0.806 72.857143 6.636 421.233 178.826
3 0.818 0.728 109.285714 6.558 427.936 376.926
4 0.726 0.630 145.714286 6.380 422.852 376.770
5 0.646 0.538 182.142857 6.632 437.173 338.948
6 0.558 0.451 218.571429 9.686 466.863 458.426
7 0.445 0.360 255.000000 6.564 462.590 417.582
HDBSCAN #1
AMI ARI std min_cluster_size min_samples time_err time_pre time_mod
0 0.899 0.850 0.000000 5833.0 1.0 6.254 398.776 2.248
1 0.900 0.850 0.000000 5833.0 5.0 6.254 398.776 2.495
2 0.899 0.850 0.000000 5833.0 10.0 6.254 398.776 2.641
3 0.899 0.850 0.000000 4667.0 1.0 6.254 398.776 2.269
4 0.900 0.850 0.000000 4667.0 5.0 6.254 398.776 2.530
5 0.899 0.850 0.000000 4667.0 10.0 6.254 398.776 2.648
6 0.899 0.850 0.000000 3500.0 1.0 6.254 398.776 2.277
7 0.900 0.850 0.000000 3500.0 5.0 6.254 398.776 2.528
8 0.899 0.850 0.000000 3500.0 10.0 6.254 398.776 2.691
9 0.899 0.850 0.000000 2333.0 1.0 6.254 398.776 2.314
10 0.900 0.850 0.000000 2333.0 5.0 6.254 398.776 2.546
11 0.899 0.850 0.000000 2333.0 10.0 6.254 398.776 2.676
12 0.899 0.850 0.000000 1556.0 1.0 6.254 398.776 2.301
13 0.900 0.850 0.000000 1556.0 5.0 6.254 398.776 2.563
14 0.899 0.850 0.000000 1556.0 10.0 6.254 398.776 2.674
15 0.899 0.850 0.000000 1077.0 1.0 6.254 398.776 2.301
16 0.900 0.850 0.000000 1077.0 5.0 6.254 398.776 2.579
17 0.899 0.850 0.000000 1077.0 10.0 6.254 398.776 2.673
18 0.899 0.850 0.000000 778.0 1.0 6.254 398.776 2.321
19 0.900 0.850 0.000000 778.0 5.0 6.254 398.776 2.557
20 0.899 0.850 0.000000 778.0 10.0 6.254 398.776 2.677
21 0.896 0.848 36.428571 5833.0 1.0 6.359 413.132 2.298
22 0.897 0.848 36.428571 5833.0 5.0 6.359 413.132 2.551
23 0.896 0.848 36.428571 5833.0 10.0 6.359 413.132 2.745
24 0.896 0.848 36.428571 4667.0 1.0 6.359 413.132 2.314
25 0.897 0.848 36.428571 4667.0 5.0 6.359 413.132 2.579
26 0.896 0.848 36.428571 4667.0 10.0 6.359 413.132 2.769
27 0.896 0.848 36.428571 3500.0 1.0 6.359 413.132 2.340
28 0.897 0.848 36.428571 3500.0 5.0 6.359 413.132 2.601
29 0.896 0.848 36.428571 3500.0 10.0 6.359 413.132 2.775
30 0.896 0.848 36.428571 2333.0 1.0 6.359 413.132 2.318
31 0.897 0.848 36.428571 2333.0 5.0 6.359 413.132 2.581
32 0.896 0.848 36.428571 2333.0 10.0 6.359 413.132 2.735
33 0.896 0.848 36.428571 1556.0 1.0 6.359 413.132 2.313
34 0.897 0.848 36.428571 1556.0 5.0 6.359 413.132 2.582
35 0.896 0.848 36.428571 1556.0 10.0 6.359 413.132 2.743
36 0.896 0.848 36.428571 1077.0 1.0 6.359 413.132 2.324
37 0.897 0.848 36.428571 1077.0 5.0 6.359 413.132 2.601
38 0.896 0.848 36.428571 1077.0 10.0 6.359 413.132 2.750
39 0.896 0.848 36.428571 778.0 1.0 6.359 413.132 2.340
40 0.897 0.848 36.428571 778.0 5.0 6.359 413.132 2.623
41 0.896 0.848 36.428571 778.0 10.0 6.359 413.132 2.764
42 0.824 0.653 72.857143 5833.0 1.0 6.636 421.233 2.241
43 0.825 0.653 72.857143 5833.0 5.0 6.636 421.233 2.539
44 0.825 0.653 72.857143 5833.0 10.0 6.636 421.233 2.688
45 0.824 0.653 72.857143 4667.0 1.0 6.636 421.233 2.256
46 0.825 0.653 72.857143 4667.0 5.0 6.636 421.233 2.560
47 0.825 0.653 72.857143 4667.0 10.0 6.636 421.233 2.694
48 0.824 0.653 72.857143 3500.0 1.0 6.636 421.233 2.244
49 0.881 0.835 72.857143 3500.0 5.0 6.636 421.233 2.569
50 0.880 0.834 72.857143 3500.0 10.0 6.636 421.233 2.720
51 0.824 0.653 72.857143 2333.0 1.0 6.636 421.233 2.252
52 0.881 0.835 72.857143 2333.0 5.0 6.636 421.233 2.569
53 0.880 0.834 72.857143 2333.0 10.0 6.636 421.233 2.761
54 0.824 0.653 72.857143 1556.0 1.0 6.636 421.233 2.251
55 0.881 0.835 72.857143 1556.0 5.0 6.636 421.233 2.481
56 0.880 0.834 72.857143 1556.0 10.0 6.636 421.233 2.629
57 0.824 0.653 72.857143 1077.0 1.0 6.636 421.233 2.166
58 0.881 0.835 72.857143 1077.0 5.0 6.636 421.233 2.606
59 0.880 0.834 72.857143 1077.0 10.0 6.636 421.233 2.737
60 0.824 0.653 72.857143 778.0 1.0 6.636 421.233 2.300
61 0.881 0.835 72.857143 778.0 5.0 6.636 421.233 2.605
62 0.880 0.834 72.857143 778.0 10.0 6.636 421.233 2.807
63 0.802 0.640 109.285714 5833.0 1.0 6.558 427.936 2.186
64 0.803 0.640 109.285714 5833.0 5.0 6.558 427.936 2.511
65 0.803 0.640 109.285714 5833.0 10.0 6.558 427.936 2.617
66 0.802 0.640 109.285714 4667.0 1.0 6.558 427.936 2.196
67 0.803 0.640 109.285714 4667.0 5.0 6.558 427.936 2.515
68 0.803 0.640 109.285714 4667.0 10.0 6.558 427.936 2.668
69 0.802 0.640 109.285714 3500.0 1.0 6.558 427.936 2.198
70 0.803 0.640 109.285714 3500.0 5.0 6.558 427.936 2.502
71 0.803 0.640 109.285714 3500.0 10.0 6.558 427.936 2.675
72 0.802 0.640 109.285714 2333.0 1.0 6.558 427.936 2.199
73 0.803 0.640 109.285714 2333.0 5.0 6.558 427.936 2.522
74 0.803 0.640 109.285714 2333.0 10.0 6.558 427.936 2.676
75 0.802 0.640 109.285714 1556.0 1.0 6.558 427.936 2.219
76 0.803 0.640 109.285714 1556.0 5.0 6.558 427.936 2.570
77 0.803 0.640 109.285714 1556.0 10.0 6.558 427.936 2.682
78 0.802 0.640 109.285714 1077.0 1.0 6.558 427.936 2.232
79 0.803 0.640 109.285714 1077.0 5.0 6.558 427.936 2.579
80 0.803 0.640 109.285714 1077.0 10.0 6.558 427.936 2.675
81 0.802 0.640 109.285714 778.0 1.0 6.558 427.936 2.250
82 0.803 0.640 109.285714 778.0 5.0 6.558 427.936 2.573
83 0.803 0.640 109.285714 778.0 10.0 6.558 427.936 2.726
84 0.724 0.519 145.714286 5833.0 1.0 6.380 422.852 2.209
85 0.724 0.519 145.714286 5833.0 5.0 6.380 422.852 2.585
86 0.760 0.611 145.714286 5833.0 10.0 6.380 422.852 2.730
87 0.724 0.519 145.714286 4667.0 1.0 6.380 422.852 2.201
88 0.724 0.519 145.714286 4667.0 5.0 6.380 422.852 2.585
89 0.760 0.611 145.714286 4667.0 10.0 6.380 422.852 2.770
90 0.724 0.519 145.714286 3500.0 1.0 6.380 422.852 2.212
91 0.724 0.519 145.714286 3500.0 5.0 6.380 422.852 2.594
92 0.760 0.611 145.714286 3500.0 10.0 6.380 422.852 2.770
93 0.724 0.519 145.714286 2333.0 1.0 6.380 422.852 2.228
94 0.724 0.519 145.714286 2333.0 5.0 6.380 422.852 2.595
95 0.760 0.611 145.714286 2333.0 10.0 6.380 422.852 2.793
96 0.724 0.519 145.714286 1556.0 1.0 6.380 422.852 2.234
97 0.724 0.519 145.714286 1556.0 5.0 6.380 422.852 2.658
98 0.760 0.611 145.714286 1556.0 10.0 6.380 422.852 2.801
99 0.724 0.519 145.714286 1077.0 1.0 6.380 422.852 2.235
100 0.724 0.519 145.714286 1077.0 5.0 6.380 422.852 2.636
101 0.760 0.611 145.714286 1077.0 10.0 6.380 422.852 2.826
102 0.724 0.519 145.714286 778.0 1.0 6.380 422.852 2.251
103 0.724 0.519 145.714286 778.0 5.0 6.380 422.852 2.614
104 0.760 0.611 145.714286 778.0 10.0 6.380 422.852 2.794
105 0.662 0.489 182.142857 5833.0 1.0 6.632 437.173 2.337
106 0.661 0.488 182.142857 5833.0 5.0 6.632 437.173 2.726
107 0.665 0.483 182.142857 5833.0 10.0 6.632 437.173 2.898
108 0.662 0.489 182.142857 4667.0 1.0 6.632 437.173 2.312
109 0.661 0.488 182.142857 4667.0 5.0 6.632 437.173 2.711
110 0.665 0.483 182.142857 4667.0 10.0 6.632 437.173 2.659
111 0.662 0.489 182.142857 3500.0 1.0 6.632 437.173 2.011
112 0.661 0.488 182.142857 3500.0 5.0 6.632 437.173 2.358
113 0.665 0.483 182.142857 3500.0 10.0 6.632 437.173 2.593
114 0.662 0.489 182.142857 2333.0 1.0 6.632 437.173 2.032
115 0.661 0.488 182.142857 2333.0 5.0 6.632 437.173 2.380
116 0.665 0.483 182.142857 2333.0 10.0 6.632 437.173 2.529
117 0.679 0.548 182.142857 1556.0 1.0 6.632 437.173 2.036
118 0.661 0.488 182.142857 1556.0 5.0 6.632 437.173 2.365
119 0.665 0.483 182.142857 1556.0 10.0 6.632 437.173 2.509
120 0.662 0.489 182.142857 1077.0 1.0 6.632 437.173 2.029
121 0.661 0.488 182.142857 1077.0 5.0 6.632 437.173 2.382
122 0.665 0.483 182.142857 1077.0 10.0 6.632 437.173 2.506
123 0.662 0.489 182.142857 778.0 1.0 6.632 437.173 2.048
124 0.661 0.488 182.142857 778.0 5.0 6.632 437.173 2.364
125 0.665 0.483 182.142857 778.0 10.0 6.632 437.173 2.520
126 0.558 0.425 218.571429 5833.0 1.0 9.686 466.863 2.015
127 0.554 0.420 218.571429 5833.0 5.0 9.686 466.863 2.406
128 0.554 0.423 218.571429 5833.0 10.0 9.686 466.863 2.589
129 0.572 0.439 218.571429 4667.0 1.0 9.686 466.863 1.994
130 0.570 0.438 218.571429 4667.0 5.0 9.686 466.863 2.410
131 0.569 0.437 218.571429 4667.0 10.0 9.686 466.863 2.588
132 0.572 0.439 218.571429 3500.0 1.0 9.686 466.863 2.027
133 0.570 0.438 218.571429 3500.0 5.0 9.686 466.863 2.426
134 0.569 0.437 218.571429 3500.0 10.0 9.686 466.863 2.573
135 0.572 0.439 218.571429 2333.0 1.0 9.686 466.863 2.042
136 0.570 0.438 218.571429 2333.0 5.0 9.686 466.863 2.442
137 0.569 0.437 218.571429 2333.0 10.0 9.686 466.863 2.585
138 0.572 0.439 218.571429 1556.0 1.0 9.686 466.863 2.061
139 0.570 0.438 218.571429 1556.0 5.0 9.686 466.863 2.469
140 0.569 0.437 218.571429 1556.0 10.0 9.686 466.863 2.583
141 0.572 0.439 218.571429 1077.0 1.0 9.686 466.863 2.067
142 0.570 0.438 218.571429 1077.0 5.0 9.686 466.863 2.475
143 0.569 0.437 218.571429 1077.0 10.0 9.686 466.863 2.607
144 0.570 0.450 218.571429 778.0 1.0 9.686 466.863 2.068
145 0.418 0.151 218.571429 778.0 5.0 9.686 466.863 2.443
146 0.569 0.437 218.571429 778.0 10.0 9.686 466.863 2.620
147 0.427 0.300 255.000000 5833.0 1.0 6.564 462.590 1.966
148 0.411 0.280 255.000000 5833.0 5.0 6.564 462.590 2.493
149 0.282 0.098 255.000000 5833.0 10.0 6.564 462.590 2.765
150 0.444 0.316 255.000000 4667.0 1.0 6.564 462.590 2.008
151 0.432 0.302 255.000000 4667.0 5.0 6.564 462.590 2.563
152 0.431 0.305 255.000000 4667.0 10.0 6.564 462.590 2.761
153 0.444 0.316 255.000000 3500.0 1.0 6.564 462.590 2.055
154 0.432 0.302 255.000000 3500.0 5.0 6.564 462.590 2.568
155 0.431 0.305 255.000000 3500.0 10.0 6.564 462.590 2.762
156 0.452 0.321 255.000000 2333.0 1.0 6.564 462.590 2.059
157 0.432 0.302 255.000000 2333.0 5.0 6.564 462.590 2.561
158 0.431 0.305 255.000000 2333.0 10.0 6.564 462.590 2.777
159 0.172 0.036 255.000000 1556.0 1.0 6.564 462.590 2.018
160 0.449 0.311 255.000000 1556.0 5.0 6.564 462.590 2.571
161 0.433 0.314 255.000000 1556.0 10.0 6.564 462.590 2.784
162 0.172 0.036 255.000000 1077.0 1.0 6.564 462.590 2.016
163 0.449 0.311 255.000000 1077.0 5.0 6.564 462.590 2.573
164 0.433 0.314 255.000000 1077.0 10.0 6.564 462.590 2.763
165 0.172 0.036 255.000000 778.0 1.0 6.564 462.590 2.005
166 0.449 0.311 255.000000 778.0 5.0 6.564 462.590 2.553
167 0.433 0.314 255.000000 778.0 10.0 6.564 462.590 2.734
KMeans #1
AMI ARI std time_err time_pre time_mod
0 0.882 0.813 0.000000 6.254 398.776 1.225
1 0.878 0.809 36.428571 6.359 413.132 1.282
2 0.859 0.788 72.857143 6.636 421.233 1.285
3 0.810 0.727 109.285714 6.558 427.936 1.473
4 0.725 0.627 145.714286 6.380 422.852 1.563
5 0.649 0.555 182.142857 6.632 437.173 1.785
6 0.550 0.463 218.571429 9.686 466.863 3.467
7 0.455 0.371 255.000000 6.564 462.590 3.420
[Figure ../_images/case_studies_Image_Clustering_Added_Noise_20_7.png: AMI and ARI clustering scores for each model as a function of the noise std]
[Figure ../_images/case_studies_Image_Clustering_Added_Noise_20_8.png: best HDBSCAN parameters (min_cluster_size, min_samples) for each noise std]
[Figure ../_images/case_studies_Image_Clustering_Added_Noise_20_9.png: true classes of the data (n=70000) in the reduced space, with added noise]

The notebook for this case study can be found here.