Christina Theodoris
commited on
Commit
·
ff551ee
1
Parent(s):
13cd541
plot umap for all labels in same view
Browse files- geneformer/emb_extractor.py +29 -23
geneformer/emb_extractor.py
CHANGED
@@ -278,14 +278,18 @@ def label_gene_embs(embs, downsampled_data, token_gene_dict):
|
|
278 |
return embs_df
|
279 |
|
280 |
|
281 |
-
def plot_umap(embs_df, emb_dims,
|
282 |
only_embs_df = embs_df.iloc[:, :emb_dims]
|
283 |
only_embs_df.index = pd.RangeIndex(0, only_embs_df.shape[0], name=None).astype(str)
|
284 |
only_embs_df.columns = pd.RangeIndex(0, only_embs_df.shape[1], name=None).astype(
|
285 |
str
|
286 |
)
|
287 |
vars_dict = {"embs": only_embs_df.columns}
|
288 |
-
|
|
|
|
|
|
|
|
|
289 |
adata = anndata.AnnData(X=only_embs_df, obs=obs_dict, var=vars_dict)
|
290 |
sc.tl.pca(adata, svd_solver="arpack")
|
291 |
sc.pp.neighbors(adata, random_state=seed)
|
@@ -296,21 +300,26 @@ def plot_umap(embs_df, emb_dims, label, output_file, kwargs_dict, seed=0):
|
|
296 |
if kwargs_dict is not None:
|
297 |
default_kwargs_dict.update(kwargs_dict)
|
298 |
|
299 |
-
|
300 |
-
|
301 |
-
|
302 |
-
|
303 |
-
|
304 |
-
|
305 |
-
|
306 |
-
|
307 |
-
|
308 |
-
|
309 |
-
|
310 |
-
|
311 |
-
|
312 |
-
|
313 |
-
|
|
|
|
|
|
|
|
|
|
|
314 |
def gen_heatmap_class_colors(labels, df):
|
315 |
pal = sns.cubehelix_palette(
|
316 |
len(Counter(labels).keys()),
|
@@ -856,12 +865,9 @@ class EmbExtractor:
|
|
856 |
f"Label {label} from labels_to_plot "
|
857 |
f"not present in provided embeddings dataframe."
|
858 |
)
|
859 |
-
|
860 |
-
|
861 |
-
|
862 |
-
Path(output_directory) / output_prefix_label
|
863 |
-
).with_suffix(".pdf")
|
864 |
-
plot_umap(embs, emb_dims, label, output_file, kwargs_dict)
|
865 |
|
866 |
if plot_style == "heatmap":
|
867 |
for label in self.labels_to_plot:
|
|
|
278 |
return embs_df
|
279 |
|
280 |
|
281 |
+
def plot_umap(embs_df, emb_dims, labels_clean, output_prefix, output_directory, kwargs_dict, seed=0):
|
282 |
only_embs_df = embs_df.iloc[:, :emb_dims]
|
283 |
only_embs_df.index = pd.RangeIndex(0, only_embs_df.shape[0], name=None).astype(str)
|
284 |
only_embs_df.columns = pd.RangeIndex(0, only_embs_df.shape[1], name=None).astype(
|
285 |
str
|
286 |
)
|
287 |
vars_dict = {"embs": only_embs_df.columns}
|
288 |
+
|
289 |
+
obs_dict = {"cell_id": list(only_embs_df.index)}
|
290 |
+
for label_i in labels_clean:
|
291 |
+
obs_dict[label_i] = list(embs_df[label_i])
|
292 |
+
|
293 |
adata = anndata.AnnData(X=only_embs_df, obs=obs_dict, var=vars_dict)
|
294 |
sc.tl.pca(adata, svd_solver="arpack")
|
295 |
sc.pp.neighbors(adata, random_state=seed)
|
|
|
300 |
if kwargs_dict is not None:
|
301 |
default_kwargs_dict.update(kwargs_dict)
|
302 |
|
303 |
+
for label_i in labels_clean:
|
304 |
+
output_prefix_label = output_prefix + f"_umap_{label_i}"
|
305 |
+
output_file = (
|
306 |
+
Path(output_directory) / output_prefix_label
|
307 |
+
).with_suffix(".pdf")
|
308 |
+
|
309 |
+
cats = set(embs_df[label_i])
|
310 |
+
|
311 |
+
with plt.rc_context():
|
312 |
+
ax = sc.pl.umap(adata, color=label_i, show=False, **default_kwargs_dict)
|
313 |
+
ax.legend(
|
314 |
+
markerscale=2,
|
315 |
+
frameon=False,
|
316 |
+
loc="center left",
|
317 |
+
bbox_to_anchor=(1, 0.5),
|
318 |
+
ncol=(1 if len(cats) <= 14 else 2 if len(cats) <= 30 else 3),
|
319 |
+
)
|
320 |
+
plt.show()
|
321 |
+
plt.savefig(output_file, bbox_inches="tight")
|
322 |
+
|
323 |
def gen_heatmap_class_colors(labels, df):
|
324 |
pal = sns.cubehelix_palette(
|
325 |
len(Counter(labels).keys()),
|
|
|
865 |
f"Label {label} from labels_to_plot "
|
866 |
f"not present in provided embeddings dataframe."
|
867 |
)
|
868 |
+
|
869 |
+
labels_clean = [label for label in self.labels_to_plot if label in emb_labels]
|
870 |
+
plot_umap(embs, emb_dims, labels_clean, output_prefix, output_directory, kwargs_dict)
|
|
|
|
|
|
|
871 |
|
872 |
if plot_style == "heatmap":
|
873 |
for label in self.labels_to_plot:
|