KashyapiNagaHarshitha committed on
Commit 1600433 (verified) · 1 Parent(s): 3995675

Upload my_modules.py

Files changed (1)
  1. my_modules.py +468 -0
my_modules.py ADDED
@@ -0,0 +1,468 @@
import os
import random
import re
import subprocess

import numpy as np
import pandas as pd
import seaborn as sb
import matplotlib.pyplot as plt
import matplotlib.colors as mplc

from scipy import signal

import plotly
import plotly.figure_factory as ff
import plotly.graph_objs as go
from plotly.offline import download_plotlyjs, init_notebook_mode, plot, iplot


# This function takes in a dataframe, changes the column names
# in various ways, and returns the dataframe.
# Pattern-based replacements use regular expressions (regex)
# for accuracy and generalizability.
def apply_header_changes(df):
    # remove lowercase x at the beginning of a name
    df.columns = df.columns.str.replace("^x", "", regex=True)
    # remove a space at the beginning of a name
    df.columns = df.columns.str.replace("^ ", "", regex=True)
    # replace spaces with underscores
    df.columns = df.columns.str.replace(" ", "_")
    # fix typos
    df.columns = df.columns.str.replace("AF_AF", "AF")
    # change "Cell Id" into "ID"
    df.columns = df.columns.str.replace("Cell Id", "ID")
    # name the index "ID" (covers the case where the ID is already the index)
    df.index.name = "ID"
    return df


def apply_df_changes(df):
    # Remove "@1" after the ID in the index
    df.index = df.index.str.replace(r'@1$', '', regex=True)
    return df

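
# Usage sketch (illustrative; the toy column and index values below are hypothetical):
#     raw = pd.DataFrame({'xNucleus AF_AF555': [1.0]}, index=['A@1'])
#     raw = apply_header_changes(raw)  # column becomes 'Nucleus_AF555', index named 'ID'
#     raw = apply_df_changes(raw)      # trailing '@1' stripped from index values
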
def compare_headers(expected, actual, name):
    # Report headers that are expected but missing, and headers that are present but unexpected
    missing_actual = np.setdiff1d(expected, actual)
    extra_actual = np.setdiff1d(actual, expected)
    if len(missing_actual) > 0:
        print("WARNING: File '" + name + "' lacks the following expected item(s): \n" + str(missing_actual))
    if len(extra_actual) > 0:
        print("WARNING: '" + name + "' has the following unexpected item(s): \n" + str(extra_actual))
    return None


def add_metadata_location(row):
    # Classify a metadata row as cytoplasm, cell, or nucleus based on its 'full_column' name
    fc = row['full_column'].lower()
    if 'cytoplasm' in fc and 'cell' not in fc and 'nucleus' not in fc:
        return 'cytoplasm'
    elif 'cell' in fc and 'cytoplasm' not in fc and 'nucleus' not in fc:
        return 'cell'
    elif 'nucleus' in fc and 'cell' not in fc and 'cytoplasm' not in fc:
        return 'nucleus'
    else:
        return 'unknown'


def get_perc(row, cell_type):
    # Percentage of the given cell type among the four counted populations, rounded to one decimal
    total = row['stroma'] + row['immune'] + row['cancer'] + row['endothelial']
    return round(row[cell_type] / total * 100, 1)

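
# Usage sketch (illustrative; assumes 'metadata' has a 'full_column' column and 'counts'
# has 'stroma', 'immune', 'cancer', and 'endothelial' columns - the actual wiring lives
# in the calling notebook):
#     metadata['localisation'] = metadata.apply(add_metadata_location, axis=1)
#     counts['cancer_perc'] = counts.apply(get_perc, axis=1, cell_type='cancer')
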
# Divide each marker (and its localisation) by the correct exposure setting for each group of samples
def divide_exp_time(col, exp_col, metadata):
    exp_time = metadata.loc[metadata['full_column'] == col.name, exp_col].values[0]
    return col / exp_time


# Subtract the matching autofluorescence (AF) channel from a marker column
def do_background_sub(col, df, metadata):
    location = metadata.loc[metadata['full_column'] == col.name, 'localisation'].values[0]
    channel = metadata.loc[metadata['full_column'] == col.name, 'Channel'].values[0]
    af_target = metadata.loc[
        (metadata['Channel'] == channel)
        & (metadata['localisation'] == location)
        & (metadata['target_lower'].str.contains(r'^af\d{3}$')),
        'full_column'].values[0]
    return col - df.loc[:, af_target]

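
# Usage sketch (illustrative; the metadata layout is an assumption based on the columns these
# helpers read - 'full_column', 'Channel', 'localisation', 'target_lower' - plus an
# exposure-time column whose name ('exp_time' here) comes from the calling notebook):
#     markers_df = df.loc[:, marker_columns]
#     markers_df = markers_df.apply(divide_exp_time, exp_col='exp_time', metadata=metadata)
#     markers_df = markers_df.apply(do_background_sub, df=markers_df, metadata=metadata)
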
"""
This function plots distributions. It takes in a string title (title), an output
directory for the saved figure (location), a list of dataframes from which to plot
(dfs), a list of dataframe names for the legend (df_names), a list of the desired
colors for the plotted samples (colors), a string for the x-axis label (x_label),
a boolean to show the legend or not (legend), optional x-axis limits (xlims),
the names of the marker(s) to plot (markers), and a list of columns that should
not be treated as markers (not_intensities). If markers is not specified, the
function plots all marker columns in one plot. markers can be either a single
string, e.g., 'my_marker', or a list, e.g., ['my_marker1', 'my_marker2'].

The function creates a distribution plot and saves it as a PNG in location.
"""


def make_distr_plot_per_sample(title, location, dfs, df_names, colors, x_label, legend,
                               xlims=None, markers=['all'], not_intensities=None):
    ### GET LIST OF MARKERS TO PLOT ###
    # If not specified by the user, plot every column of the first df that is not
    # listed in not_intensities. Parameters with default values are optional.
    if not_intensities is None:
        not_intensities = []
    if markers == ["all"]:
        markers = [c for c in dfs[0].columns.values if c not in not_intensities]
    elif not isinstance(markers, list):
        markers = [markers]
    # Make the marker list a set to keep only unique values, then put it back into a list
    markers = list(set(markers))

    ### GET XLIMS ###
    if xlims is None:
        mins = [df.loc[:, markers].min().min() for df in dfs]
        maxes = [df.loc[:, markers].max().max() for df in dfs]
        xlims = [min(mins), max(maxes)]
    if not isinstance(xlims, list):
        print("Problem - xlims is not a list. Exiting method...")
        return None

    ### CHECK DATA CAN BE PLOTTED ###
    # Data with only one unique value cannot be plotted as a distribution -
    # the kernel density estimate fails with a singular matrix error.
    group_labels = []
    hist_data = []
    # Iterate through all dataframes (dfs)
    for i in range(len(dfs)):
        # Iterate through all marker labels
        for f in markers:
            if dfs[i][f].nunique() != 1:
                # If we have more than one df, prefix the legend label with the df name
                # so it is clear which trace belongs to which df
                if len(df_names) > 1:
                    group_labels.append(df_names[i] + "_" + f)
                else:
                    group_labels.append(f)
                # add the data to the data list
                hist_data.append(dfs[i][f])
    # if no column had more than one unique value, there is nothing to plot
    if len(group_labels) < 1:
        print("No markers plotted - all were singular value. Names and markers were "
              + str(df_names) + ", " + str(markers))
        return None

    ### TRANSFORM COLOR ITEMS TO CORRECT TYPE ###
    # plotly expects colors as 'rgb(r, g, b)' strings, not tuples
    if isinstance(colors[0], tuple):
        colors = ['rgb' + str(color) for color in colors]

    ### PLOT DATA ###
    # Create distribution plot
    fig = ff.create_distplot(hist_data, group_labels, bin_size=0.1,
                             colors=colors, show_rug=False)
    # Adjust title, font, background color, legend
    fig.update_layout(title_text=title, font=dict(size=18),
                      plot_bgcolor='white', showlegend=legend)
    # Adjust opacity
    fig.update_traces(opacity=0.6)
    # Adjust x-axis parameters
    fig.update_xaxes(title_text=x_label, showline=True, linewidth=2, linecolor='black',
                     tickfont=dict(size=18), range=xlims)
    # Adjust y-axis parameters
    fig.update_yaxes(title_text="Kernel density estimate", showline=True, linewidth=1,
                     linecolor='black', tickfont=dict(size=18))

    ### SAVE PLOT ###
    # Save to PNG (static image export requires the kaleido package)
    filename = os.path.join(location, title.replace(" ", "_") + ".png")
    fig.write_image(filename)
    return None

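
# Usage sketch (illustrative; the dataframe, marker name, colors, and output directory are hypothetical):
#     make_distr_plot_per_sample(title='CD45 raw intensity', location='plots',
#                                dfs=[df], df_names=['sample_1'], colors=[(255, 0, 0)],
#                                x_label='Intensity', legend=True,
#                                markers='CD45_Cell_Intensity_Average')
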
# Shorten long, underscore-separated feature names to their shortest unique prefixes
def shorten_feature_names(long_names):
    # Start with the first underscore-separated token of each name as its nickname
    name_dict = dict(zip(long_names, [n.split('_')[0] for n in long_names]))
    names_lts, long_names, iteration = shorten_feature_names_helper(name_dict, long_names, 1)
    # names_lts = names long-to-short
    # names_stl = names short-to-long
    names_stl = {}
    for long_name, short_name in names_lts.items():
        names_stl[short_name] = long_name
    return names_lts, names_stl


def shorten_feature_names_helper(name_dict, long_names, iteration):
    # If the number of unique nicknames equals the number of long names,
    # every nickname is unique and the work here is done
    if len(set(name_dict.values())) == len(long_names):
        return name_dict, long_names, iteration

    # Otherwise, some nicknames collide, so extend them with one more
    # underscore-separated token and try again
    iteration += 1
    nicknames_set = set()
    non_unique_nicknames = set()
    # construct the set of current nicknames and note any duplicates
    for long_name in long_names:
        short_name = name_dict[long_name]
        if short_name in nicknames_set:
            non_unique_nicknames.add(short_name)
        else:
            nicknames_set.add(short_name)

    # figure out all long names associated with the non-unique short names
    trouble_long_names = set()
    for long_name in long_names:
        short_name = name_dict[long_name]
        if short_name in non_unique_nicknames:
            trouble_long_names.add(long_name)

    # extend the nicknames of all names associated with the non-unique short nicknames
    for long_name in trouble_long_names:
        name_dict[long_name] = '_'.join(long_name.split('_')[0:iteration])
    return shorten_feature_names_helper(name_dict, long_names, iteration)

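
# Usage sketch (illustrative; the feature names are hypothetical):
#     lts, stl = shorten_feature_names(['CD45_Cell_Mean', 'CD45_Nucleus_Mean', 'DAPI_Cell_Mean'])
#     # lts -> {'CD45_Cell_Mean': 'CD45_Cell', 'CD45_Nucleus_Mean': 'CD45_Nucleus', 'DAPI_Cell_Mean': 'DAPI'}
#     # stl is the reverse (short-to-long) mapping
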
# NOTE: heatmap_function2 is an unfinished variant of heatmap_function below.
# It only sets a larger font scale and extracts the row/column color mappings;
# it never draws or saves a clustermap.
def heatmap_function2(title,
                      data,
                      method, metric, cmap,
                      cbar_kws, xticklabels, save_loc,
                      row_cluster, col_cluster,
                      annotations={'rows': [], 'cols': []}):

    sb.set(font_scale=6.0)

    # Extract row and column mappings
    row_mappings = []
    col_mappings = []
    for ann in annotations['rows']:
        row_mappings.append(ann['mapping'])
    for ann in annotations['cols']:
        col_mappings.append(ann['mapping'])
    # If the lists are empty, convert them to None so seaborn accepts
    # them as the row_colors or col_colors arguments
    if len(row_mappings) == 0:
        row_mappings = None
    if len(col_mappings) == 0:
        col_mappings = None

def heatmap_function(title,
                     data,
                     method, metric, cmap,
                     cbar_kws, xticklabels, save_loc,
                     row_cluster, col_cluster,
                     annotations={'rows': [], 'cols': []}):

    sb.set(font_scale=2.0)

    # Extract row and column mappings
    row_mappings = []
    col_mappings = []
    for ann in annotations['rows']:
        row_mappings.append(ann['mapping'])
    for ann in annotations['cols']:
        col_mappings.append(ann['mapping'])
    # If the lists are empty, convert them to None so seaborn accepts
    # them as the row_colors or col_colors arguments
    if len(row_mappings) == 0:
        row_mappings = None
    if len(col_mappings) == 0:
        col_mappings = None

    # Create clustermap
    g = sb.clustermap(data=data,
                      robust=True,
                      method=method, metric=metric,
                      cmap=cmap,
                      row_cluster=row_cluster, col_cluster=col_cluster,
                      figsize=(40, 30),
                      row_colors=row_mappings, col_colors=col_mappings,
                      yticklabels=False,
                      cbar_kws=cbar_kws,
                      xticklabels=xticklabels)

    # Rotate the x labels slightly
    plt.setp(g.ax_heatmap.xaxis.get_majorticklabels(), rotation=45)

    # Add title
    g.fig.suptitle(title, fontsize=60.0)

    # And now for the legends:
    # iterate through 'rows', 'cols'
    for ann_type in annotations.keys():
        # iterate through each individual annotation feature
        for ann in annotations[ann_type]:
            color_dict = ann['dict']
            handles = []
            for item in color_dict.keys():
                h = g.ax_col_dendrogram.bar(0, 0, color=color_dict[item], label=item,
                                            linewidth=0)
                handles.append(h)
            legend = plt.legend(handles=handles, loc=ann['location'], title=ann['label'],
                                bbox_to_anchor=ann['bbox_to_anchor'],
                                bbox_transform=plt.gcf().transFigure)
            plt.gca().add_artist(legend)

    # Save image
    filename = os.path.join(save_loc, title.lower().replace(" ", "_") + ".png")
    g.savefig(filename)

    return None

# Sources:
# https://stackoverflow.com/questions/27988846/how-to-express-classes-on-the-axis-of-a-heatmap-in-seaborn
# https://matplotlib.org/3.1.1/tutorials/intermediate/legend_guide.html

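
# Usage sketch (illustrative; 'data' and 'cell_types' are hypothetical - 'cell_types' is a
# pandas Series indexed like data's rows. The keys read from each annotation are the ones
# heatmap_function uses: 'mapping', 'dict', 'label', 'location', and 'bbox_to_anchor'):
#     cluster_colors = {'cancer': 'red', 'immune': 'blue'}
#     annotations = {'rows': [{'mapping': cell_types.map(cluster_colors),
#                              'dict': cluster_colors,
#                              'label': 'Cell type',
#                              'location': 'upper right',
#                              'bbox_to_anchor': (1.0, 0.9)}],
#                    'cols': []}
#     heatmap_function('My clustermap', data, method='average', metric='euclidean',
#                      cmap='viridis', cbar_kws={'label': 'z-score'}, xticklabels=True,
#                      save_loc='plots', row_cluster=True, col_cluster=True,
#                      annotations=annotations)
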
def verify_line_no(filename, lines_read):
    # Use the Linux "wc -l" command to get the number of lines in the unopened file
    wc = subprocess.check_output(['wc', '-l', filename]).decode("utf-8")
    # Take that string, turn it into a list, extract the first item,
    # and make it an int - this is the number of lines in the file
    wc = int(wc.split()[0])
    if lines_read != wc:
        print("WARNING: '" + filename + "' has " + str(wc) +
              " lines, but imported dataframe has "
              + str(lines_read) + " (including header).")
    return None

372
+ def rgb_tuple_from_str(rgb_str):
373
+ rgb_str = rgb_str.replace("(","").replace(")","").replace(" ","")
374
+ rgb = list(map(float,rgb_str.split(",")))
375
+ return tuple(rgb)
376
+
377
+ def color_dict_to_df(cd, column_name):
378
+ df = pd.DataFrame.from_dict(cd, orient = 'index')
379
+ df['rgb'] = df.apply(lambda row: (np.float64(row[0]), np.float64(row[1]), np.float64(row[2])), axis = 1)
380
+ df = df.drop(columns = [0,1,2])
381
+ df['hex'] = df.apply(lambda row: mplc.to_hex(row['rgb']), axis = 1)
382
+ df[column_name] = df.index
383
+ return df
384
+
385
+
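
# Usage sketch (illustrative; the colors and column name are hypothetical):
#     cell_type_colors = {'cancer': (1.0, 0.0, 0.0), 'immune': (0.0, 0.0, 1.0)}
#     color_df = color_dict_to_df(cell_type_colors, 'cell_type')
#     # color_df gains 'rgb', 'hex', and 'cell_type' columns
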
386
+ # p-values that are less than or equal to 0.05
387
+ def p_add_star(row):
388
+ m = [str('{:0.3e}'.format(m)) + "*"
389
+ if m <= 0.05 \
390
+ else str('{:0.3e}'.format(m))
391
+ for m in row ]
392
+ return pd.Series(m)
393
+
394
+ # assigns a specific number of asterisks based on the thresholds
395
+ def p_to_star(row):
396
+ output = []
397
+ for item in row:
398
+ if item <= 0.001:
399
+ stars = 3
400
+ elif item <= 0.01:
401
+ stars = 2
402
+ elif item <= 0.05:
403
+ stars = 1
404
+ else:
405
+ stars = 0
406
+ value = ''
407
+ for i in range(stars):
408
+ value += '*'
409
+ output.append(value)
410
+ return pd.Series(output)
411
+
412
+
413
+
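
# Usage sketch (illustrative; 'pvals' is a hypothetical dataframe of p-values):
#     formatted = pvals.apply(p_add_star, axis=1)   # e.g. '3.200e-02*'
#     stars = pvals.apply(p_to_star, axis=1)        # '', '*', '**', or '***'
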
414
+ def plot_gaussian_distributions(df):
415
+ # Initialize thresholds list to store all calculated thresholds
416
+ all_thresholds = []
417
+
418
+ # Iterate over all columns except the first one (assuming the first one is non-numeric or an index)
419
+ for column in df.columns:
420
+ # Extract the marker data
421
+ marker_data = df[column]
422
+
423
+ # Calculating mean and standard deviation for each marker
424
+ m_mean, m_std = np.mean(marker_data), np.std(marker_data)
425
+
426
+ # Generating x values for the Gaussian curve
427
+ x_vals = np.linspace(marker_data.min(), marker_data.max(), 100)
428
+
429
+ # Calculating Gaussian distribution curve
430
+ gaussian_curve = (1 / (m_std * np.sqrt(2 * np.pi))) * np.exp(-(x_vals - m_mean) ** 2 / (2 * m_std ** 2))
431
+
432
+ # Creating figure for Gaussian distribution for each marker
433
+ fig = go.Figure()
434
+ fig.add_trace(go.Scatter(x=x_vals, y=gaussian_curve, mode='lines', name=f'{column} Gaussian Distribution'))
435
+ fig.update_layout(title=f'Gaussian Distribution for {column} Marker')
436
+
437
+ # Calculating thresholds based on each marker's distribution
438
+ seuil_1sigma = m_mean + m_std
439
+ seuil_2sigma = m_mean + 2 * m_std
440
+ seuil_3sigma = m_mean + 3 * m_std
441
+
442
+ # Display the figures with thresholds
443
+ fig.add_shape(type='line', x0=seuil_1sigma, y0=0, x1=seuil_1sigma, y1=np.max(gaussian_curve),
444
+ line=dict(color='red', dash='dash'), name=f'Seuil 1σ: {seuil_1sigma:.2f}')
445
+ fig.add_shape(type='line', x0=seuil_2sigma, y0=0, x1=seuil_2sigma, y1=np.max(gaussian_curve),
446
+ line=dict(color='green', dash='dash'), name=f'Seuil 2σ: {seuil_2sigma:.2f}')
447
+ fig.add_shape(type='line', x0=seuil_3sigma, y0=0, x1=seuil_3sigma, y1=np.max(gaussian_curve),
448
+ line=dict(color='blue', dash='dash'), name=f'Seuil 3σ: {seuil_3sigma:.2f}')
449
+
450
+ # Add markers and values to the plot
451
+ fig.add_trace(go.Scatter(x=[seuil_1sigma, seuil_2sigma, seuil_3sigma],
452
+ y=[0, 0, 0],
453
+ mode='markers+text',
454
+ text=[f'{seuil_1sigma:.2f}', f'{seuil_2sigma:.2f}', f'{seuil_3sigma:.2f}'],
455
+ textposition="top center",
456
+ marker=dict(size=10, color=['red', 'green', 'blue']),
457
+ name='Threshold Values'))
458
+
459
+ fig.show()
460
+
461
+ # Append thresholds for each marker to the list
462
+ all_thresholds.append((column, seuil_1sigma, seuil_2sigma, seuil_3sigma)) # Include the column name
463
+
464
+ # Return thresholds for all markers
465
+ return all_thresholds
466
+
467
+
468
+
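
# Usage sketch (illustrative; 'intensities' is a hypothetical dataframe of numeric marker columns):
#     thresholds = plot_gaussian_distributions(intensities)
#     # thresholds is a list of (column, mean + 1*std, mean + 2*std, mean + 3*std) tuples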