jfaustin commited on
Commit
3a103f2
·
1 Parent(s): 216492b

select correlation plots

Browse files
folding_studio_demo/app.py CHANGED
@@ -8,7 +8,7 @@ from gradio_molecule3d import Molecule3D
8
  import pandas as pd
9
 
10
  from folding_studio_demo.predict import predict
11
- from folding_studio_demo.correlate import fake_predict_and_correlate, SCORE_COLUMNS
12
 
13
  logger = logging.getLogger(__name__)
14
 
@@ -167,21 +167,29 @@ def create_correlation_tab():
167
 
168
  gr.Markdown("# Prediction and correlation")
169
  with gr.Row():
170
- with gr.Column():
171
- with gr.Row():
172
- fake_predict_btn = gr.Button("Predict structures of all complexes")
173
- with gr.Row():
174
- prediction_dataframe = gr.Dataframe(label="Predicted Structures Data")
175
- with gr.Column():
176
- correlation_ranking_plot = gr.Plot(label="Correlation ranking")
177
- correlation_plot = gr.Plot(label="Correlation with binding affinity")
 
178
 
179
  fake_predict_btn.click(
180
  fn=lambda x: fake_predict_and_correlate(spr_data_with_scores, SCORE_COLUMNS),
181
  inputs=None,
182
- outputs=[prediction_dataframe, correlation_ranking_plot, correlation_plot]
183
  )
184
 
 
 
 
 
 
 
 
185
  def __main__():
186
  with gr.Blocks(title="Folding Studio Demo") as demo:
187
  gr.Markdown(
 
8
  import pandas as pd
9
 
10
  from folding_studio_demo.predict import predict
11
+ from folding_studio_demo.correlate import fake_predict_and_correlate, SCORE_COLUMNS, select_correlation_plot
12
 
13
  logger = logging.getLogger(__name__)
14
 
 
167
 
168
  gr.Markdown("# Prediction and correlation")
169
  with gr.Row():
170
+ fake_predict_btn = gr.Button("Predict structures of all complexes")
171
+ with gr.Row():
172
+ prediction_dataframe = gr.Dataframe(label="Predicted Structures Data")
173
+ with gr.Row():
174
+ correlation_ranking_plot = gr.Plot(label="Correlation ranking")
175
+ with gr.Row():
176
+ # User can select the columns to display in the correlation plot
177
+ correlation_column = gr.Dropdown(label="Score data to display", choices=SCORE_COLUMNS, multiselect=False)
178
+ correlation_plot = gr.Plot(label="Correlation with binding affinity")
179
 
180
  fake_predict_btn.click(
181
  fn=lambda x: fake_predict_and_correlate(spr_data_with_scores, SCORE_COLUMNS),
182
  inputs=None,
183
+ outputs=[prediction_dataframe, correlation_ranking_plot]
184
  )
185
 
186
+ # Call function to update the correlation plot when the user selects the columns
187
+ correlation_column.change(
188
+ fn=lambda score: select_correlation_plot(spr_data_with_scores, score),
189
+ inputs=correlation_column,
190
+ outputs=correlation_plot
191
+ )
192
+
193
  def __main__():
194
  with gr.Blocks(title="Folding Studio Demo") as demo:
195
  gr.Markdown(
folding_studio_demo/correlate.py CHANGED
@@ -65,26 +65,37 @@ def fake_predict_and_correlate(spr_data_with_scores: pd.DataFrame, score_cols: l
65
  showlegend=False
66
  )
67
 
 
 
 
 
 
 
 
68
  # corr_plot is a scatter plot of the correlation between the binding affinity and each of the scores
69
- scatters = []
70
- for score_col in score_cols:
71
- scatters.append(
72
- go.Scatter(
73
- x=spr_data_with_scores[kd_col],
74
- y=spr_data_with_scores[score_col],
75
- name=f"{kd_col} vs {score_col}",
76
- mode='markers', # Only show markers/dots, no lines
77
- hovertemplate="<i>Score:</i> %{y}<br><i>KD:</i> %{x:.2f}<br>"
78
- )
79
- )
80
- corr_plot = go.Figure(data=scatters)
81
  corr_plot.update_layout(
82
  xaxis_title="KD (nM)",
83
- yaxis_title="Score",
84
  template="simple_white",
85
- xaxis_type="log" # Set x-axis to logarithmic scale
86
  )
87
- cols_to_show = [kd_col]
88
- cols_to_show.extend(score_cols)
89
-
90
- return spr_data_with_scores[cols_to_show].round(2), corr_ranking_plot, corr_plot
 
 
 
 
 
 
 
 
 
65
  showlegend=False
66
  )
67
 
68
+ cols_to_show = [kd_col]
69
+ cols_to_show.extend(score_cols)
70
+
71
+ return spr_data_with_scores[cols_to_show].round(2), corr_ranking_plot
72
+
73
+ def select_correlation_plot(spr_data_with_scores: pd.DataFrame, score: str) -> go.Figure:
74
+ """Select the correlation plot to display."""
75
  # corr_plot is a scatter plot of the correlation between the binding affinity and each of the scores
76
+ scatter = go.Scatter(
77
+ x=spr_data_with_scores["KD (nM)"],
78
+ y=spr_data_with_scores[score],
79
+ name=f"KD (nM) vs {score}",
80
+ mode='markers', # Only show markers/dots, no lines
81
+ hovertemplate="<i>Score:</i> %{y}<br><i>KD:</i> %{x:.2f}<br>"
82
+ )
83
+ corr_plot = go.Figure(data=scatter)
 
 
 
 
84
  corr_plot.update_layout(
85
  xaxis_title="KD (nM)",
86
+ yaxis_title=score,
87
  template="simple_white",
88
+ # xaxis_type="log" # Set x-axis to logarithmic scale
89
  )
90
+ # compute the correlation line
91
+ corr_line = np.polyfit(spr_data_with_scores["KD (nM)"], spr_data_with_scores[score], 1)
92
+ corr_line_x = np.linspace(min(spr_data_with_scores["KD (nM)"]), max(spr_data_with_scores["KD (nM)"]), 100)
93
+ corr_line_y = corr_line[0] * corr_line_x + corr_line[1]
94
+ # add the correlation line to the plot
95
+ corr_plot.add_trace(go.Scatter(
96
+ x=corr_line_x,
97
+ y=corr_line_y,
98
+ mode='lines',
99
+ name=f"Correlation line for {score}"
100
+ ))
101
+ return corr_plot