aliabd commited on
Commit
aa31f21
·
1 Parent(s): 4da04a0
Files changed (2) hide show
  1. app.py +292 -0
  2. requirements.txt +2 -0
app.py ADDED
@@ -0,0 +1,292 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # URL: https://huggingface.co/spaces/gradio/clustering
2
+ # imports
3
+ import gradio as gr
4
+ import math
5
+ from functools import partial
6
+ import matplotlib.pyplot as plt
7
+ import numpy as np
8
+ from sklearn.cluster import (
9
+ AgglomerativeClustering, Birch, DBSCAN, KMeans, MeanShift, OPTICS, SpectralClustering, estimate_bandwidth
10
+ )
11
+ from sklearn.datasets import make_blobs, make_circles, make_moons
12
+ from sklearn.mixture import GaussianMixture
13
+ from sklearn.neighbors import kneighbors_graph
14
+ from sklearn.preprocessing import StandardScaler
15
+
16
+ # loading models and setting up
17
+ plt.style.use('seaborn')
18
+ SEED = 0
19
+ MAX_CLUSTERS = 10
20
+ N_SAMPLES = 1000
21
+ N_COLS = 3
22
+ FIGSIZE = 7, 7 # does not affect size in webpage
23
+ COLORS = [
24
+ 'blue', 'orange', 'green', 'red', 'purple', 'brown', 'pink', 'gray', 'olive', 'cyan'
25
+ ]
26
+ assert len(COLORS) >= MAX_CLUSTERS, "Not enough different colors for all clusters"
27
+ np.random.seed(SEED)
28
+
29
+ # defining core fns
30
+
31
+ def normalize(X):
32
+ return StandardScaler().fit_transform(X)
33
+
34
+
35
+ def get_regular(n_clusters):
36
+ # spiral pattern
37
+ centers = [
38
+ [0, 0],
39
+ [1, 0],
40
+ [1, 1],
41
+ [0, 1],
42
+ [-1, 1],
43
+ [-1, 0],
44
+ [-1, -1],
45
+ [0, -1],
46
+ [1, -1],
47
+ [2, -1],
48
+ ][:n_clusters]
49
+ assert len(centers) == n_clusters
50
+ X, labels = make_blobs(n_samples=N_SAMPLES, centers=centers, cluster_std=0.25, random_state=SEED)
51
+ return normalize(X), labels
52
+
53
+
54
+ def get_circles(n_clusters):
55
+ X, labels = make_circles(n_samples=N_SAMPLES, factor=0.5, noise=0.05, random_state=SEED)
56
+ return normalize(X), labels
57
+
58
+
59
+ def get_moons(n_clusters):
60
+ X, labels = make_moons(n_samples=N_SAMPLES, noise=0.05, random_state=SEED)
61
+ return normalize(X), labels
62
+
63
+
64
+ def get_noise(n_clusters):
65
+ np.random.seed(SEED)
66
+ X, labels = np.random.rand(N_SAMPLES, 2), np.random.randint(0, n_clusters, size=(N_SAMPLES,))
67
+ return normalize(X), labels
68
+
69
+
70
+ def get_anisotropic(n_clusters):
71
+ X, labels = make_blobs(n_samples=N_SAMPLES, centers=n_clusters, random_state=170)
72
+ transformation = [[0.6, -0.6], [-0.4, 0.8]]
73
+ X = np.dot(X, transformation)
74
+ return X, labels
75
+
76
+
77
+ def get_varied(n_clusters):
78
+ cluster_std = [1.0, 2.5, 0.5, 1.0, 2.5, 0.5, 1.0, 2.5, 0.5, 1.0][:n_clusters]
79
+ assert len(cluster_std) == n_clusters
80
+ X, labels = make_blobs(
81
+ n_samples=N_SAMPLES, centers=n_clusters, cluster_std=cluster_std, random_state=SEED
82
+ )
83
+ return normalize(X), labels
84
+
85
+
86
+ def get_spiral(n_clusters):
87
+ # from https://scikit-learn.org/stable/auto_examples/cluster/plot_agglomerative_clustering.html
88
+ np.random.seed(SEED)
89
+ t = 1.5 * np.pi * (1 + 3 * np.random.rand(1, N_SAMPLES))
90
+ x = t * np.cos(t)
91
+ y = t * np.sin(t)
92
+ X = np.concatenate((x, y))
93
+ X += 0.7 * np.random.randn(2, N_SAMPLES)
94
+ X = np.ascontiguousarray(X.T)
95
+
96
+ labels = np.zeros(N_SAMPLES, dtype=int)
97
+ return normalize(X), labels
98
+
99
+
100
+ DATA_MAPPING = {
101
+ 'regular': get_regular,
102
+ 'circles': get_circles,
103
+ 'moons': get_moons,
104
+ 'spiral': get_spiral,
105
+ 'noise': get_noise,
106
+ 'anisotropic': get_anisotropic,
107
+ 'varied': get_varied,
108
+ }
109
+
110
+
111
+ def get_groundtruth_model(X, labels, n_clusters, **kwargs):
112
+ # dummy model to show true label distribution
113
+ class Dummy:
114
+ def __init__(self, y):
115
+ self.labels_ = labels
116
+
117
+ return Dummy(labels)
118
+
119
+
120
+ def get_kmeans(X, labels, n_clusters, **kwargs):
121
+ model = KMeans(init="k-means++", n_clusters=n_clusters, n_init=10, random_state=SEED)
122
+ model.set_params(**kwargs)
123
+ return model.fit(X)
124
+
125
+
126
+ def get_dbscan(X, labels, n_clusters, **kwargs):
127
+ model = DBSCAN(eps=0.3)
128
+ model.set_params(**kwargs)
129
+ return model.fit(X)
130
+
131
+
132
+ def get_agglomerative(X, labels, n_clusters, **kwargs):
133
+ connectivity = kneighbors_graph(
134
+ X, n_neighbors=n_clusters, include_self=False
135
+ )
136
+ # make connectivity symmetric
137
+ connectivity = 0.5 * (connectivity + connectivity.T)
138
+ model = AgglomerativeClustering(
139
+ n_clusters=n_clusters, linkage="ward", connectivity=connectivity
140
+ )
141
+ model.set_params(**kwargs)
142
+ return model.fit(X)
143
+
144
+
145
+ def get_meanshift(X, labels, n_clusters, **kwargs):
146
+ bandwidth = estimate_bandwidth(X, quantile=0.25)
147
+ model = MeanShift(bandwidth=bandwidth, bin_seeding=True)
148
+ model.set_params(**kwargs)
149
+ return model.fit(X)
150
+
151
+
152
+ def get_spectral(X, labels, n_clusters, **kwargs):
153
+ model = SpectralClustering(
154
+ n_clusters=n_clusters,
155
+ eigen_solver="arpack",
156
+ affinity="nearest_neighbors",
157
+ )
158
+ model.set_params(**kwargs)
159
+ return model.fit(X)
160
+
161
+
162
+ def get_optics(X, labels, n_clusters, **kwargs):
163
+ model = OPTICS(
164
+ min_samples=7,
165
+ xi=0.05,
166
+ min_cluster_size=0.1,
167
+ )
168
+ model.set_params(**kwargs)
169
+ return model.fit(X)
170
+
171
+
172
+ def get_birch(X, labels, n_clusters, **kwargs):
173
+ model = Birch(n_clusters=n_clusters)
174
+ model.set_params(**kwargs)
175
+ return model.fit(X)
176
+
177
+
178
+ def get_gaussianmixture(X, labels, n_clusters, **kwargs):
179
+ model = GaussianMixture(
180
+ n_components=n_clusters, covariance_type="full", random_state=SEED,
181
+ )
182
+ model.set_params(**kwargs)
183
+ return model.fit(X)
184
+
185
+
186
+ MODEL_MAPPING = {
187
+ 'True labels': get_groundtruth_model,
188
+ 'KMeans': get_kmeans,
189
+ 'DBSCAN': get_dbscan,
190
+ 'MeanShift': get_meanshift,
191
+ 'SpectralClustering': get_spectral,
192
+ 'OPTICS': get_optics,
193
+ 'Birch': get_birch,
194
+ 'GaussianMixture': get_gaussianmixture,
195
+ 'AgglomerativeClustering': get_agglomerative,
196
+ }
197
+
198
+
199
+ def plot_clusters(ax, X, labels):
200
+ set_clusters = set(labels)
201
+ set_clusters.discard(-1) # -1 signifiies outliers, which we plot separately
202
+ for label, color in zip(sorted(set_clusters), COLORS):
203
+ idx = labels == label
204
+ if not sum(idx):
205
+ continue
206
+ ax.scatter(X[idx, 0], X[idx, 1], color=color)
207
+
208
+ # show outliers (if any)
209
+ idx = labels == -1
210
+ if sum(idx):
211
+ ax.scatter(X[idx, 0], X[idx, 1], c='k', marker='x')
212
+
213
+ ax.grid(None)
214
+ ax.set_xticks([])
215
+ ax.set_yticks([])
216
+ return ax
217
+
218
+
219
+ def cluster(dataset: str, n_clusters: int, clustering_algorithm: str):
220
+ if isinstance(n_clusters, dict):
221
+ n_clusters = n_clusters['value']
222
+ else:
223
+ n_clusters = int(n_clusters)
224
+
225
+ X, labels = DATA_MAPPING[dataset](n_clusters)
226
+ model = MODEL_MAPPING[clustering_algorithm](X, labels, n_clusters=n_clusters)
227
+ if hasattr(model, "labels_"):
228
+ y_pred = model.labels_.astype(int)
229
+ else:
230
+ y_pred = model.predict(X)
231
+
232
+ fig, ax = plt.subplots(figsize=FIGSIZE)
233
+
234
+ plot_clusters(ax, X, y_pred)
235
+ ax.set_title(clustering_algorithm, fontsize=16)
236
+
237
+ return fig
238
+
239
+
240
+ title = "Clustering with Scikit-learn"
241
+ description = (
242
+ "This example shows how different clustering algorithms work. Simply pick "
243
+ "the dataset and the number of clusters to see how the clustering algorithms work. "
244
+ "Colored cirles are (predicted) labels and black x are outliers."
245
+ )
246
+
247
+
248
+ def iter_grid(n_rows, n_cols):
249
+ # create a grid using gradio Block
250
+ for _ in range(n_rows):
251
+ with gr.Row():
252
+ for _ in range(n_cols):
253
+ with gr.Column():
254
+ yield
255
+
256
+ # starting a block
257
+
258
+ with gr.Blocks(title=title) as demo:
259
+ # adding text as HTML and Markdown
260
+ gr.HTML(f"<b>{title}</b>")
261
+ gr.Markdown(description)
262
+
263
+ # setting up the inputs
264
+ input_models = list(MODEL_MAPPING)
265
+ input_data = gr.Radio(
266
+ list(DATA_MAPPING),
267
+ value="regular",
268
+ label="dataset"
269
+ )
270
+ input_n_clusters = gr.Slider(
271
+ minimum=1,
272
+ maximum=MAX_CLUSTERS,
273
+ value=4,
274
+ step=1,
275
+ label='Number of clusters'
276
+ )
277
+ n_rows = int(math.ceil(len(input_models) / N_COLS))
278
+ counter = 0
279
+ for _ in iter_grid(n_rows, N_COLS):
280
+ if counter >= len(input_models):
281
+ break
282
+
283
+ input_model = input_models[counter]
284
+ # defining the output
285
+ plot = gr.Plot(label=input_model)
286
+ fn = partial(cluster, clustering_algorithm=input_model)
287
+ input_data.change(fn=fn, inputs=[input_data, input_n_clusters], outputs=plot)
288
+ input_n_clusters.change(fn=fn, inputs=[input_data, input_n_clusters], outputs=plot)
289
+ counter += 1
290
+
291
+ # launch
292
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ matplotlib>=3.5.2
2
+ scikit-learn>=1.0.1