select correlation plots
Browse files- folding_studio_demo/app.py +18 -10
- folding_studio_demo/correlate.py +29 -18
folding_studio_demo/app.py
CHANGED
@@ -8,7 +8,7 @@ from gradio_molecule3d import Molecule3D
|
|
8 |
import pandas as pd
|
9 |
|
10 |
from folding_studio_demo.predict import predict
|
11 |
-
from folding_studio_demo.correlate import fake_predict_and_correlate, SCORE_COLUMNS
|
12 |
|
13 |
logger = logging.getLogger(__name__)
|
14 |
|
@@ -167,21 +167,29 @@ def create_correlation_tab():
|
|
167 |
|
168 |
gr.Markdown("# Prediction and correlation")
|
169 |
with gr.Row():
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
|
|
178 |
|
179 |
fake_predict_btn.click(
|
180 |
fn=lambda x: fake_predict_and_correlate(spr_data_with_scores, SCORE_COLUMNS),
|
181 |
inputs=None,
|
182 |
-
outputs=[prediction_dataframe, correlation_ranking_plot
|
183 |
)
|
184 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
185 |
def __main__():
|
186 |
with gr.Blocks(title="Folding Studio Demo") as demo:
|
187 |
gr.Markdown(
|
|
|
8 |
import pandas as pd
|
9 |
|
10 |
from folding_studio_demo.predict import predict
|
11 |
+
from folding_studio_demo.correlate import fake_predict_and_correlate, SCORE_COLUMNS, select_correlation_plot
|
12 |
|
13 |
logger = logging.getLogger(__name__)
|
14 |
|
|
|
167 |
|
168 |
gr.Markdown("# Prediction and correlation")
|
169 |
with gr.Row():
|
170 |
+
fake_predict_btn = gr.Button("Predict structures of all complexes")
|
171 |
+
with gr.Row():
|
172 |
+
prediction_dataframe = gr.Dataframe(label="Predicted Structures Data")
|
173 |
+
with gr.Row():
|
174 |
+
correlation_ranking_plot = gr.Plot(label="Correlation ranking")
|
175 |
+
with gr.Row():
|
176 |
+
# User can select the columns to display in the correlation plot
|
177 |
+
correlation_column = gr.Dropdown(label="Score data to display", choices=SCORE_COLUMNS, multiselect=False)
|
178 |
+
correlation_plot = gr.Plot(label="Correlation with binding affinity")
|
179 |
|
180 |
fake_predict_btn.click(
|
181 |
fn=lambda x: fake_predict_and_correlate(spr_data_with_scores, SCORE_COLUMNS),
|
182 |
inputs=None,
|
183 |
+
outputs=[prediction_dataframe, correlation_ranking_plot]
|
184 |
)
|
185 |
|
186 |
+
# Call function to update the correlation plot when the user selects the columns
|
187 |
+
correlation_column.change(
|
188 |
+
fn=lambda score: select_correlation_plot(spr_data_with_scores, score),
|
189 |
+
inputs=correlation_column,
|
190 |
+
outputs=correlation_plot
|
191 |
+
)
|
192 |
+
|
193 |
def __main__():
|
194 |
with gr.Blocks(title="Folding Studio Demo") as demo:
|
195 |
gr.Markdown(
|
folding_studio_demo/correlate.py
CHANGED
@@ -65,26 +65,37 @@ def fake_predict_and_correlate(spr_data_with_scores: pd.DataFrame, score_cols: l
|
|
65 |
showlegend=False
|
66 |
)
|
67 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
68 |
# corr_plot is a scatter plot of the correlation between the binding affinity and each of the scores
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
hovertemplate="<i>Score:</i> %{y}<br><i>KD:</i> %{x:.2f}<br>"
|
78 |
-
)
|
79 |
-
)
|
80 |
-
corr_plot = go.Figure(data=scatters)
|
81 |
corr_plot.update_layout(
|
82 |
xaxis_title="KD (nM)",
|
83 |
-
yaxis_title=
|
84 |
template="simple_white",
|
85 |
-
xaxis_type="log" # Set x-axis to logarithmic scale
|
86 |
)
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
showlegend=False
|
66 |
)
|
67 |
|
68 |
+
cols_to_show = [kd_col]
|
69 |
+
cols_to_show.extend(score_cols)
|
70 |
+
|
71 |
+
return spr_data_with_scores[cols_to_show].round(2), corr_ranking_plot
|
72 |
+
|
73 |
+
def select_correlation_plot(spr_data_with_scores: pd.DataFrame, score: str) -> go.Figure:
|
74 |
+
"""Select the correlation plot to display."""
|
75 |
# corr_plot is a scatter plot of the correlation between the binding affinity and each of the scores
|
76 |
+
scatter = go.Scatter(
|
77 |
+
x=spr_data_with_scores["KD (nM)"],
|
78 |
+
y=spr_data_with_scores[score],
|
79 |
+
name=f"KD (nM) vs {score}",
|
80 |
+
mode='markers', # Only show markers/dots, no lines
|
81 |
+
hovertemplate="<i>Score:</i> %{y}<br><i>KD:</i> %{x:.2f}<br>"
|
82 |
+
)
|
83 |
+
corr_plot = go.Figure(data=scatter)
|
|
|
|
|
|
|
|
|
84 |
corr_plot.update_layout(
|
85 |
xaxis_title="KD (nM)",
|
86 |
+
yaxis_title=score,
|
87 |
template="simple_white",
|
88 |
+
# xaxis_type="log" # Set x-axis to logarithmic scale
|
89 |
)
|
90 |
+
# compute the correlation line
|
91 |
+
corr_line = np.polyfit(spr_data_with_scores["KD (nM)"], spr_data_with_scores[score], 1)
|
92 |
+
corr_line_x = np.linspace(min(spr_data_with_scores["KD (nM)"]), max(spr_data_with_scores["KD (nM)"]), 100)
|
93 |
+
corr_line_y = corr_line[0] * corr_line_x + corr_line[1]
|
94 |
+
# add the correlation line to the plot
|
95 |
+
corr_plot.add_trace(go.Scatter(
|
96 |
+
x=corr_line_x,
|
97 |
+
y=corr_line_y,
|
98 |
+
mode='lines',
|
99 |
+
name=f"Correlation line for {score}"
|
100 |
+
))
|
101 |
+
return corr_plot
|