jfaustin committed on
Commit
8192214
·
1 Parent(s): ecc0f00

choose correlation ranking plot

Browse files
folding_studio_demo/app.py CHANGED
@@ -10,7 +10,9 @@ from gradio_molecule3d import Molecule3D
10
  from folding_studio_demo.correlate import (
11
  SCORE_COLUMNS,
12
  fake_predict_and_correlate,
13
- make_correlation_plot,
 
 
14
  )
15
  from folding_studio_demo.predict import predict, predict_comparison
16
 
@@ -211,7 +213,15 @@ def create_correlation_tab():
211
  with gr.Row():
212
  prediction_dataframe = gr.Dataframe(label="Predicted Structures Data")
213
  with gr.Row():
214
- correlation_ranking_plot = gr.Plot(label="Correlation ranking")
 
 
 
 
 
 
 
 
215
  with gr.Row():
216
  with gr.Column():
217
  with gr.Row():
@@ -229,21 +239,33 @@ def create_correlation_tab():
229
  fn=lambda x: fake_predict_and_correlate(
230
  spr_data_with_scores, SCORE_COLUMNS, ["Antibody Name", "KD (nM)"]
231
  ),
232
- inputs=None,
233
  outputs=[prediction_dataframe, correlation_ranking_plot, correlation_plot],
234
  )
235
 
236
- def update_plot(score, use_log):
237
- return make_correlation_plot(spr_data_with_scores, score, use_log)
238
 
 
 
 
 
 
 
239
  correlation_column.change(
240
- fn=update_plot,
241
  inputs=[correlation_column, log_scale],
242
  outputs=correlation_plot,
243
  )
 
 
 
 
 
 
244
 
245
  log_scale.change(
246
- fn=update_plot,
247
  inputs=[correlation_column, log_scale],
248
  outputs=correlation_plot,
249
  )
 
10
  from folding_studio_demo.correlate import (
11
  SCORE_COLUMNS,
12
  fake_predict_and_correlate,
13
+ make_regression_plot,
14
+ compute_correlation_data,
15
+ plot_correlation_ranking
16
  )
17
  from folding_studio_demo.predict import predict, predict_comparison
18
 
 
213
  with gr.Row():
214
  prediction_dataframe = gr.Dataframe(label="Predicted Structures Data")
215
  with gr.Row():
216
+ with gr.Column():
217
+ correlation_type = gr.Radio(
218
+ choices=["Spearman", "Pearson", "R²"],
219
+ value="Spearman",
220
+ label="Correlation Type",
221
+ interactive=True
222
+ )
223
+ with gr.Column():
224
+ correlation_ranking_plot = gr.Plot(label="Correlation ranking")
225
  with gr.Row():
226
  with gr.Column():
227
  with gr.Row():
 
239
  fn=lambda x: fake_predict_and_correlate(
240
  spr_data_with_scores, SCORE_COLUMNS, ["Antibody Name", "KD (nM)"]
241
  ),
242
+ inputs=[correlation_type],
243
  outputs=[prediction_dataframe, correlation_ranking_plot, correlation_plot],
244
  )
245
 
246
+ def update_regression_plot(score, use_log):
247
+ return make_regression_plot(spr_data_with_scores, score, use_log)
248
 
249
+ def update_correlation_plot(correlation_type):
250
+ logger.info(f"Updating correlation plot for {correlation_type}")
251
+ corr_data = compute_correlation_data(spr_data_with_scores, SCORE_COLUMNS)
252
+ logger.info(f"Correlation data: {corr_data}")
253
+ return plot_correlation_ranking(corr_data, correlation_type)
254
+
255
  correlation_column.change(
256
+ fn=update_regression_plot,
257
  inputs=[correlation_column, log_scale],
258
  outputs=correlation_plot,
259
  )
260
+
261
+ correlation_type.change(
262
+ fn=update_correlation_plot,
263
+ inputs=[correlation_type],
264
+ outputs=correlation_ranking_plot,
265
+ )
266
 
267
  log_scale.change(
268
+ fn=update_regression_plot,
269
  inputs=[correlation_column, log_scale],
270
  outputs=correlation_plot,
271
  )
folding_studio_demo/correlate.py CHANGED
@@ -1,8 +1,9 @@
1
  import logging
2
  import pandas as pd
 
3
  import numpy as np
4
  import plotly.graph_objects as go
5
- from scipy.stats import spearmanr
6
 
7
  logger = logging.getLogger(__name__)
8
 
@@ -30,16 +31,32 @@ SCORE_COLUMNS = [
30
  "interface_ptm_multimer"
31
  ]
32
 
33
- def fake_predict_and_correlate(spr_data_with_scores: pd.DataFrame, score_cols: list[str], main_cols: list[str]) -> tuple[pd.DataFrame, go.Figure]:
34
- """Fake predict structures of all complexes and correlate the results."""
 
 
 
 
35
  corr_data = []
36
  spr_data_with_scores["log_kd"] = np.log10(spr_data_with_scores["KD (nM)"])
37
  kd_col = "KD (nM)"
38
- for score_col in score_cols:
39
- logger.info(f"Computing correlation between {score_col} and KD (nM)")
40
- res = spearmanr(spr_data_with_scores[kd_col], spr_data_with_scores[score_col])
41
- corr_data.append({"score": score_col, "correlation": res.statistic, "p-value": res.pvalue})
42
- logger.info(f"Correlation between {score_col} and KD (nM): {res.statistic}")
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
  corr_data = pd.DataFrame(corr_data)
45
  # Find the lines in corr_data with NaN values and remove them
@@ -47,34 +64,48 @@ def fake_predict_and_correlate(spr_data_with_scores: pd.DataFrame, score_cols: l
47
  # Sort correlation data by correlation value
48
  corr_data = corr_data.sort_values('correlation', ascending=True)
49
 
 
 
 
 
 
50
  # Create bar plot of correlations
 
51
  corr_ranking_plot = go.Figure(data=[
52
  go.Bar(
53
- x=corr_data["correlation"],
54
- y=corr_data["score"],
55
- name="correlation",
 
56
  orientation='h',
57
  hovertemplate="<i>Score:</i> %{y}<br><i>Correlation:</i> %{x:.3f}<br>"
58
  )
59
  ])
60
  corr_ranking_plot.update_layout(
61
  title="Correlation with Binding Affinity",
62
- yaxis_title="Score Type",
63
- xaxis_title="Spearman Correlation",
64
  template="simple_white",
65
  showlegend=False
66
  )
 
 
 
 
 
 
 
67
 
68
  cols_to_show = main_cols[:]
69
  cols_to_show.extend(score_cols)
70
 
71
- corr_plot = make_correlation_plot(spr_data_with_scores, score_cols[0], use_log=False)
72
 
73
  return spr_data_with_scores[cols_to_show].round(2), corr_ranking_plot, corr_plot
74
 
75
- def make_correlation_plot(spr_data_with_scores: pd.DataFrame, score: str, use_log: bool) -> go.Figure:
76
- """Select the correlation plot to display."""
77
- # corr_plot is a scatter plot of the correlation between the binding affinity and each of the scores
78
  scatter = go.Scatter(
79
  x=spr_data_with_scores["KD (nM)"],
80
  y=spr_data_with_scores[score],
@@ -97,11 +128,11 @@ def make_correlation_plot(spr_data_with_scores: pd.DataFrame, score: str, use_lo
97
  ),
98
  xaxis_type="log" if use_log else "linear" # Set x-axis to logarithmic scale
99
  )
100
- # compute the correlation line
101
  corr_line = np.polyfit(spr_data_with_scores["KD (nM)"], spr_data_with_scores[score], 1)
102
  corr_line_x = np.linspace(min(spr_data_with_scores["KD (nM)"]), max(spr_data_with_scores["KD (nM)"]), 100)
103
  corr_line_y = corr_line[0] * corr_line_x + corr_line[1]
104
- # add the correlation line to the plot
105
  corr_plot.add_trace(go.Scatter(
106
  x=corr_line_x,
107
  y=corr_line_y,
 
1
  import logging
2
  import pandas as pd
3
+ from pathlib import Path
4
  import numpy as np
5
  import plotly.graph_objects as go
6
+ from scipy.stats import spearmanr, pearsonr, linregress
7
 
8
  logger = logging.getLogger(__name__)
9
 
 
31
  "interface_ptm_multimer"
32
  ]
33
 
34
def compute_correlation_data(spr_data_with_scores: pd.DataFrame, score_cols: list[str]) -> pd.DataFrame:
    """Compute Spearman, Pearson and R² statistics between KD and each score.

    Results are cached in ``corr_data.csv`` in the current working directory.
    The cache is reused only when it contains every expected column, every
    correlation type, and exactly the requested score columns; otherwise the
    statistics are recomputed and the file is overwritten. The cache is NOT
    keyed on the data values themselves.

    Args:
        spr_data_with_scores: Table holding a ``"KD (nM)"`` column and one
            column per entry of ``score_cols``. A ``"log_kd"`` column is added
            to it in place whenever the statistics are (re)computed.
        score_cols: Names of the score columns to correlate against KD.

    Returns:
        A DataFrame with columns ``correlation_type``, ``score``,
        ``correlation`` and ``p-value``, sorted by ascending ``correlation``.
    """
    # The keys must match the labels offered by the UI radio button; the
    # linregress entry was previously registered under an empty key, so the
    # "R²" branch below was never taken and `.statistic` was read from a
    # linregress result (which only exposes `rvalue`).
    corr_funcs = {
        "Spearman": spearmanr,
        "Pearson": pearsonr,
        "R²": linregress,
    }
    required_columns = {"correlation_type", "score", "correlation", "p-value"}

    corr_data_file = Path("corr_data.csv")
    if corr_data_file.exists():
        try:
            cached = pd.read_csv(corr_data_file)
        except (pd.errors.EmptyDataError, pd.errors.ParserError):
            cached = None
        if (
            cached is not None
            and required_columns.issubset(cached.columns)
            and set(cached["correlation_type"].dropna()) == set(corr_funcs)
            and set(cached["score"].dropna()) == set(score_cols)
        ):
            logger.info(f"Loading correlation data from {corr_data_file}")
            return cached
        logger.info(f"Ignoring stale correlation data in {corr_data_file}")

    # Kept from the original implementation: the input frame gains a log10(KD)
    # column as a side effect.
    spr_data_with_scores["log_kd"] = np.log10(spr_data_with_scores["KD (nM)"])
    kd_col = "KD (nM)"

    corr_data = []
    for correlation_type, corr_func in corr_funcs.items():
        for score_col in score_cols:
            logger.info(f"Computing {correlation_type} correlation between {score_col} and KD (nM)")
            res = corr_func(spr_data_with_scores[kd_col], spr_data_with_scores[score_col])
            logger.info(f"Correlation function: {corr_func}")
            # linregress reports the Pearson r as `rvalue`; the rank/linear
            # correlation functions report their coefficient as `statistic`.
            correlation_value = res.rvalue**2 if correlation_type == "R²" else res.statistic
            corr_data.append({
                "correlation_type": correlation_type,
                "score": score_col,
                "correlation": correlation_value,
                "p-value": res.pvalue,
            })
            logger.info(f"Correlation {correlation_type} between {score_col} and KD (nM): {correlation_value}")

    corr_data = pd.DataFrame(corr_data, columns=["correlation_type", "score", "correlation", "p-value"])
    # Find the lines in corr_data with NaN values and remove them
    # NOTE(review): the exact filtering statement is elided in the visible
    # diff hunk; dropna() implements the stated intent — confirm against the
    # full file.
    corr_data = corr_data.dropna()
    # Sort correlation data by correlation value
    corr_data = corr_data.sort_values('correlation', ascending=True)

    corr_data.to_csv(corr_data_file, index=False)

    return corr_data
70
+
71
def plot_correlation_ranking(corr_data: pd.DataFrame, correlation_type: str) -> go.Figure:
    """Build a horizontal bar chart of the scores for one correlation type.

    Args:
        corr_data: Table with ``correlation_type``, ``score`` and
            ``correlation`` columns.
        correlation_type: Value of ``correlation_type`` whose rows are plotted;
            also used as the trace name and the x-axis title.

    Returns:
        A figure with one horizontal bar per matching row, in the row order of
        ``corr_data``.
    """
    # Keep only the rows that belong to the requested correlation type.
    selected = corr_data.loc[corr_data["correlation_type"] == correlation_type]

    bars = go.Bar(
        x=selected["correlation"],
        y=selected["score"],
        name=correlation_type,
        text=selected["correlation"],
        orientation='h',
        hovertemplate="<i>Score:</i> %{y}<br><i>Correlation:</i> %{x:.3f}<br>",
    )

    figure = go.Figure(data=[bars])
    figure.update_layout(
        title="Correlation with Binding Affinity",
        yaxis_title="Score",
        xaxis_title=correlation_type,
        template="simple_white",
        showlegend=False,
    )
    return figure
92
+
93
def fake_predict_and_correlate(
    spr_data_with_scores: pd.DataFrame,
    score_cols: list[str],
    main_cols: list[str],
    correlation_type: str = "Spearman",
) -> tuple[pd.DataFrame, go.Figure, go.Figure]:
    """Fake predict structures of all complexes and correlate the results.

    Args:
        spr_data_with_scores: Table containing the columns named in
            ``main_cols`` and ``score_cols``.
        score_cols: Score columns to correlate; the first one seeds the
            regression plot.
        main_cols: Columns displayed ahead of the score columns.
        correlation_type: Correlation type shown in the ranking plot. Defaults
            to ``"Spearman"``, the value that was previously hard-coded.

    Returns:
        A 3-tuple of the displayed columns rounded to 2 decimals, the
        correlation ranking figure, and the regression figure for
        ``score_cols[0]``.

    Raises:
        IndexError: If ``score_cols`` is empty.
    """
    corr_data = compute_correlation_data(spr_data_with_scores, score_cols)
    corr_ranking_plot = plot_correlation_ranking(corr_data, correlation_type)

    # Build a new list so the caller's main_cols is left untouched.
    cols_to_show = [*main_cols, *score_cols]

    corr_plot = make_regression_plot(spr_data_with_scores, score_cols[0], use_log=False)

    return spr_data_with_scores[cols_to_show].round(2), corr_ranking_plot, corr_plot
105
 
106
+ def make_regression_plot(spr_data_with_scores: pd.DataFrame, score: str, use_log: bool) -> go.Figure:
107
+ """Select the regression plot to display."""
108
+ # corr_plot is a scatter plot of the regression between the binding affinity and each of the scores
109
  scatter = go.Scatter(
110
  x=spr_data_with_scores["KD (nM)"],
111
  y=spr_data_with_scores[score],
 
128
  ),
129
  xaxis_type="log" if use_log else "linear" # Set x-axis to logarithmic scale
130
  )
131
+ # compute the regression line
132
  corr_line = np.polyfit(spr_data_with_scores["KD (nM)"], spr_data_with_scores[score], 1)
133
  corr_line_x = np.linspace(min(spr_data_with_scores["KD (nM)"]), max(spr_data_with_scores["KD (nM)"]), 100)
134
  corr_line_y = corr_line[0] * corr_line_x + corr_line[1]
135
+ # add the regression line to the plot
136
  corr_plot.add_trace(go.Scatter(
137
  x=corr_line_x,
138
  y=corr_line_y,