Christina Theodoris committed on
Commit
ff551ee
·
1 Parent(s): 13cd541

plot umap for all labels in same view

Browse files
Files changed (1) hide show
  1. geneformer/emb_extractor.py +29 -23
geneformer/emb_extractor.py CHANGED
@@ -278,14 +278,18 @@ def label_gene_embs(embs, downsampled_data, token_gene_dict):
278
  return embs_df
279
 
280
 
281
- def plot_umap(embs_df, emb_dims, label, output_file, kwargs_dict, seed=0):
282
  only_embs_df = embs_df.iloc[:, :emb_dims]
283
  only_embs_df.index = pd.RangeIndex(0, only_embs_df.shape[0], name=None).astype(str)
284
  only_embs_df.columns = pd.RangeIndex(0, only_embs_df.shape[1], name=None).astype(
285
  str
286
  )
287
  vars_dict = {"embs": only_embs_df.columns}
288
- obs_dict = {"cell_id": list(only_embs_df.index), f"{label}": list(embs_df[label])}
 
 
 
 
289
  adata = anndata.AnnData(X=only_embs_df, obs=obs_dict, var=vars_dict)
290
  sc.tl.pca(adata, svd_solver="arpack")
291
  sc.pp.neighbors(adata, random_state=seed)
@@ -296,21 +300,26 @@ def plot_umap(embs_df, emb_dims, label, output_file, kwargs_dict, seed=0):
296
  if kwargs_dict is not None:
297
  default_kwargs_dict.update(kwargs_dict)
298
 
299
- cats = set(embs_df[label])
300
-
301
- with plt.rc_context():
302
- ax = sc.pl.umap(adata, color=label, show=False, **default_kwargs_dict)
303
- ax.legend(
304
- markerscale=2,
305
- frameon=False,
306
- loc="center left",
307
- bbox_to_anchor=(1, 0.5),
308
- ncol=(1 if len(cats) <= 14 else 2 if len(cats) <= 30 else 3),
309
- )
310
- plt.show()
311
- plt.savefig(output_file, bbox_inches="tight")
312
-
313
-
 
 
 
 
 
314
  def gen_heatmap_class_colors(labels, df):
315
  pal = sns.cubehelix_palette(
316
  len(Counter(labels).keys()),
@@ -856,12 +865,9 @@ class EmbExtractor:
856
  f"Label {label} from labels_to_plot "
857
  f"not present in provided embeddings dataframe."
858
  )
859
- continue
860
- output_prefix_label = output_prefix + f"_umap_{label}"
861
- output_file = (
862
- Path(output_directory) / output_prefix_label
863
- ).with_suffix(".pdf")
864
- plot_umap(embs, emb_dims, label, output_file, kwargs_dict)
865
 
866
  if plot_style == "heatmap":
867
  for label in self.labels_to_plot:
 
278
  return embs_df
279
 
280
 
281
+ def plot_umap(embs_df, emb_dims, labels_clean, output_prefix, output_directory, kwargs_dict, seed=0):
282
  only_embs_df = embs_df.iloc[:, :emb_dims]
283
  only_embs_df.index = pd.RangeIndex(0, only_embs_df.shape[0], name=None).astype(str)
284
  only_embs_df.columns = pd.RangeIndex(0, only_embs_df.shape[1], name=None).astype(
285
  str
286
  )
287
  vars_dict = {"embs": only_embs_df.columns}
288
+
289
+ obs_dict = {"cell_id": list(only_embs_df.index)}
290
+ for label_i in labels_clean:
291
+ obs_dict[label_i] = list(embs_df[label_i])
292
+
293
  adata = anndata.AnnData(X=only_embs_df, obs=obs_dict, var=vars_dict)
294
  sc.tl.pca(adata, svd_solver="arpack")
295
  sc.pp.neighbors(adata, random_state=seed)
 
300
  if kwargs_dict is not None:
301
  default_kwargs_dict.update(kwargs_dict)
302
 
303
+ for label_i in labels_clean:
304
+ output_prefix_label = output_prefix + f"_umap_{label_i}"
305
+ output_file = (
306
+ Path(output_directory) / output_prefix_label
307
+ ).with_suffix(".pdf")
308
+
309
+ cats = set(embs_df[label_i])
310
+
311
+ with plt.rc_context():
312
+ ax = sc.pl.umap(adata, color=label_i, show=False, **default_kwargs_dict)
313
+ ax.legend(
314
+ markerscale=2,
315
+ frameon=False,
316
+ loc="center left",
317
+ bbox_to_anchor=(1, 0.5),
318
+ ncol=(1 if len(cats) <= 14 else 2 if len(cats) <= 30 else 3),
319
+ )
320
+ plt.show()
321
+ plt.savefig(output_file, bbox_inches="tight")
322
+
323
  def gen_heatmap_class_colors(labels, df):
324
  pal = sns.cubehelix_palette(
325
  len(Counter(labels).keys()),
 
865
  f"Label {label} from labels_to_plot "
866
  f"not present in provided embeddings dataframe."
867
  )
868
+
869
+ labels_clean = [label for label in self.labels_to_plot if label in emb_labels]
870
+ plot_umap(embs, emb_dims, labels_clean, output_prefix, output_directory, kwargs_dict)
 
 
 
871
 
872
  if plot_style == "heatmap":
873
  for label in self.labels_to_plot: