choose correlation ranking plot
Browse files
- folding_studio_demo/app.py +29 -7
- folding_studio_demo/correlate.py +50 -19
folding_studio_demo/app.py
CHANGED
@@ -10,7 +10,9 @@ from gradio_molecule3d import Molecule3D
|
|
10 |
from folding_studio_demo.correlate import (
|
11 |
SCORE_COLUMNS,
|
12 |
fake_predict_and_correlate,
|
13 |
-
|
|
|
|
|
14 |
)
|
15 |
from folding_studio_demo.predict import predict, predict_comparison
|
16 |
|
@@ -211,7 +213,15 @@ def create_correlation_tab():
|
|
211 |
with gr.Row():
|
212 |
prediction_dataframe = gr.Dataframe(label="Predicted Structures Data")
|
213 |
with gr.Row():
|
214 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
215 |
with gr.Row():
|
216 |
with gr.Column():
|
217 |
with gr.Row():
|
@@ -229,21 +239,33 @@ def create_correlation_tab():
|
|
229 |
fn=lambda x: fake_predict_and_correlate(
|
230 |
spr_data_with_scores, SCORE_COLUMNS, ["Antibody Name", "KD (nM)"]
|
231 |
),
|
232 |
-
inputs=
|
233 |
outputs=[prediction_dataframe, correlation_ranking_plot, correlation_plot],
|
234 |
)
|
235 |
|
236 |
-
def
|
237 |
-
return
|
238 |
|
|
|
|
|
|
|
|
|
|
|
|
|
239 |
correlation_column.change(
|
240 |
-
fn=
|
241 |
inputs=[correlation_column, log_scale],
|
242 |
outputs=correlation_plot,
|
243 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
244 |
|
245 |
log_scale.change(
|
246 |
-
fn=
|
247 |
inputs=[correlation_column, log_scale],
|
248 |
outputs=correlation_plot,
|
249 |
)
|
|
|
10 |
from folding_studio_demo.correlate import (
|
11 |
SCORE_COLUMNS,
|
12 |
fake_predict_and_correlate,
|
13 |
+
make_regression_plot,
|
14 |
+
compute_correlation_data,
|
15 |
+
plot_correlation_ranking
|
16 |
)
|
17 |
from folding_studio_demo.predict import predict, predict_comparison
|
18 |
|
|
|
213 |
with gr.Row():
|
214 |
prediction_dataframe = gr.Dataframe(label="Predicted Structures Data")
|
215 |
with gr.Row():
|
216 |
+
with gr.Column():
|
217 |
+
correlation_type = gr.Radio(
|
218 |
+
choices=["Spearman", "Pearson", "R²"],
|
219 |
+
value="Spearman",
|
220 |
+
label="Correlation Type",
|
221 |
+
interactive=True
|
222 |
+
)
|
223 |
+
with gr.Column():
|
224 |
+
correlation_ranking_plot = gr.Plot(label="Correlation ranking")
|
225 |
with gr.Row():
|
226 |
with gr.Column():
|
227 |
with gr.Row():
|
|
|
239 |
fn=lambda x: fake_predict_and_correlate(
|
240 |
spr_data_with_scores, SCORE_COLUMNS, ["Antibody Name", "KD (nM)"]
|
241 |
),
|
242 |
+
inputs=[correlation_type],
|
243 |
outputs=[prediction_dataframe, correlation_ranking_plot, correlation_plot],
|
244 |
)
|
245 |
|
246 |
+
def update_regression_plot(score, use_log):
|
247 |
+
return make_regression_plot(spr_data_with_scores, score, use_log)
|
248 |
|
249 |
+
def update_correlation_plot(correlation_type):
|
250 |
+
logger.info(f"Updating correlation plot for {correlation_type}")
|
251 |
+
corr_data = compute_correlation_data(spr_data_with_scores, SCORE_COLUMNS)
|
252 |
+
logger.info(f"Correlation data: {corr_data}")
|
253 |
+
return plot_correlation_ranking(corr_data, correlation_type)
|
254 |
+
|
255 |
correlation_column.change(
|
256 |
+
fn=update_regression_plot,
|
257 |
inputs=[correlation_column, log_scale],
|
258 |
outputs=correlation_plot,
|
259 |
)
|
260 |
+
|
261 |
+
correlation_type.change(
|
262 |
+
fn=update_correlation_plot,
|
263 |
+
inputs=[correlation_type],
|
264 |
+
outputs=correlation_ranking_plot,
|
265 |
+
)
|
266 |
|
267 |
log_scale.change(
|
268 |
+
fn=update_regression_plot,
|
269 |
inputs=[correlation_column, log_scale],
|
270 |
outputs=correlation_plot,
|
271 |
)
|
folding_studio_demo/correlate.py
CHANGED
@@ -1,8 +1,9 @@
|
|
1 |
import logging
|
2 |
import pandas as pd
|
|
|
3 |
import numpy as np
|
4 |
import plotly.graph_objects as go
|
5 |
-
from scipy.stats import spearmanr
|
6 |
|
7 |
logger = logging.getLogger(__name__)
|
8 |
|
@@ -30,16 +31,32 @@ SCORE_COLUMNS = [
|
|
30 |
"interface_ptm_multimer"
|
31 |
]
|
32 |
|
33 |
-
def
|
34 |
-
|
|
|
|
|
|
|
|
|
35 |
corr_data = []
|
36 |
spr_data_with_scores["log_kd"] = np.log10(spr_data_with_scores["KD (nM)"])
|
37 |
kd_col = "KD (nM)"
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
43 |
|
44 |
corr_data = pd.DataFrame(corr_data)
|
45 |
# Find the lines in corr_data with NaN values and remove them
|
@@ -47,34 +64,48 @@ def fake_predict_and_correlate(spr_data_with_scores: pd.DataFrame, score_cols: l
|
|
47 |
# Sort correlation data by correlation value
|
48 |
corr_data = corr_data.sort_values('correlation', ascending=True)
|
49 |
|
|
|
|
|
|
|
|
|
|
|
50 |
# Create bar plot of correlations
|
|
|
51 |
corr_ranking_plot = go.Figure(data=[
|
52 |
go.Bar(
|
53 |
-
x=
|
54 |
-
y=
|
55 |
-
name=
|
|
|
56 |
orientation='h',
|
57 |
hovertemplate="<i>Score:</i> %{y}<br><i>Correlation:</i> %{x:.3f}<br>"
|
58 |
)
|
59 |
])
|
60 |
corr_ranking_plot.update_layout(
|
61 |
title="Correlation with Binding Affinity",
|
62 |
-
yaxis_title="Score
|
63 |
-
xaxis_title=
|
64 |
template="simple_white",
|
65 |
showlegend=False
|
66 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
67 |
|
68 |
cols_to_show = main_cols[:]
|
69 |
cols_to_show.extend(score_cols)
|
70 |
|
71 |
-
corr_plot =
|
72 |
|
73 |
return spr_data_with_scores[cols_to_show].round(2), corr_ranking_plot, corr_plot
|
74 |
|
75 |
-
def
|
76 |
-
"""Select the
|
77 |
-
# corr_plot is a scatter plot of the
|
78 |
scatter = go.Scatter(
|
79 |
x=spr_data_with_scores["KD (nM)"],
|
80 |
y=spr_data_with_scores[score],
|
@@ -97,11 +128,11 @@ def make_correlation_plot(spr_data_with_scores: pd.DataFrame, score: str, use_lo
|
|
97 |
),
|
98 |
xaxis_type="log" if use_log else "linear" # Set x-axis to logarithmic scale
|
99 |
)
|
100 |
-
# compute the
|
101 |
corr_line = np.polyfit(spr_data_with_scores["KD (nM)"], spr_data_with_scores[score], 1)
|
102 |
corr_line_x = np.linspace(min(spr_data_with_scores["KD (nM)"]), max(spr_data_with_scores["KD (nM)"]), 100)
|
103 |
corr_line_y = corr_line[0] * corr_line_x + corr_line[1]
|
104 |
-
# add the
|
105 |
corr_plot.add_trace(go.Scatter(
|
106 |
x=corr_line_x,
|
107 |
y=corr_line_y,
|
|
|
1 |
import logging
|
2 |
import pandas as pd
|
3 |
+
from pathlib import Path
|
4 |
import numpy as np
|
5 |
import plotly.graph_objects as go
|
6 |
+
from scipy.stats import spearmanr, pearsonr, linregress
|
7 |
|
8 |
logger = logging.getLogger(__name__)
|
9 |
|
|
|
31 |
"interface_ptm_multimer"
|
32 |
]
|
33 |
|
34 |
+
def compute_correlation_data(spr_data_with_scores: pd.DataFrame, score_cols: list[str]) -> pd.DataFrame:
|
35 |
+
corr_data_file = Path("corr_data.csv")
|
36 |
+
if corr_data_file.exists():
|
37 |
+
logger.info(f"Loading correlation data from {corr_data_file}")
|
38 |
+
return pd.read_csv(corr_data_file)
|
39 |
+
|
40 |
corr_data = []
|
41 |
spr_data_with_scores["log_kd"] = np.log10(spr_data_with_scores["KD (nM)"])
|
42 |
kd_col = "KD (nM)"
|
43 |
+
corr_funcs = {}
|
44 |
+
corr_funcs["Spearman"] = spearmanr
|
45 |
+
corr_funcs["Pearson"] = pearsonr
|
46 |
+
corr_funcs["R²"] = linregress
|
47 |
+
for correlation_type, corr_func in corr_funcs.items():
|
48 |
+
for score_col in score_cols:
|
49 |
+
logger.info(f"Computing {correlation_type} correlation between {score_col} and KD (nM)")
|
50 |
+
res = corr_func(spr_data_with_scores[kd_col], spr_data_with_scores[score_col])
|
51 |
+
logger.info(f"Correlation function: {corr_func}")
|
52 |
+
correlation_value = res.rvalue**2 if correlation_type == "R²" else res.statistic
|
53 |
+
corr_data.append({
|
54 |
+
"correlation_type": correlation_type,
|
55 |
+
"score": score_col,
|
56 |
+
"correlation": correlation_value,
|
57 |
+
"p-value": res.pvalue
|
58 |
+
})
|
59 |
+
logger.info(f"Correlation {correlation_type} between {score_col} and KD (nM): {correlation_value}")
|
60 |
|
61 |
corr_data = pd.DataFrame(corr_data)
|
62 |
# Find the lines in corr_data with NaN values and remove them
|
|
|
64 |
# Sort correlation data by correlation value
|
65 |
corr_data = corr_data.sort_values('correlation', ascending=True)
|
66 |
|
67 |
+
corr_data.to_csv("corr_data.csv", index=False)
|
68 |
+
|
69 |
+
return corr_data
|
70 |
+
|
71 |
+
def plot_correlation_ranking(corr_data: pd.DataFrame, correlation_type: str) -> go.Figure:
    """Build a horizontal bar chart of per-score correlations for one metric.

    Args:
        corr_data: Table with at least the columns ``correlation_type``,
            ``score`` and ``correlation``.
        correlation_type: Value matched against the ``correlation_type``
            column; also used as the trace name and the x-axis title.

    Returns:
        A figure with one horizontal bar per matching row, in the row order
        of ``corr_data``.
    """
    # Keep only the rows computed for the requested metric.
    matches_type = corr_data["correlation_type"] == correlation_type
    selected = corr_data[matches_type]

    bars = go.Bar(
        x=selected["correlation"],
        y=selected["score"],
        name=correlation_type,
        text=selected["correlation"],
        orientation='h',
        hovertemplate="<i>Score:</i> %{y}<br><i>Correlation:</i> %{x:.3f}<br>",
    )

    layout_options = {
        "title": "Correlation with Binding Affinity",
        "yaxis_title": "Score",
        "xaxis_title": correlation_type,
        "template": "simple_white",
        "showlegend": False,
    }

    ranking_figure = go.Figure(data=[bars])
    ranking_figure.update_layout(**layout_options)
    return ranking_figure
|
92 |
+
|
93 |
+
def fake_predict_and_correlate(
    spr_data_with_scores: pd.DataFrame,
    score_cols: list[str],
    main_cols: list[str],
    correlation_type: str = "Spearman",
) -> tuple[pd.DataFrame, go.Figure, go.Figure]:
    """Fake predict structures of all complexes and correlate the results.

    The scores already present in ``spr_data_with_scores`` are correlated
    and plotted; this function only calls the correlation and plotting
    helpers.

    Args:
        spr_data_with_scores: Table containing every column named in
            ``main_cols`` and ``score_cols``.
        score_cols: Score column names. Must be non-empty: the first entry
            selects the score shown in the initial regression plot.
        main_cols: Columns displayed before the score columns in the
            returned table.
        correlation_type: Metric displayed in the ranking plot. Defaults to
            ``"Spearman"``, the value that used to be hard-coded, so existing
            three-argument callers behave as before.

    Returns:
        A 3-tuple ``(table, ranking_plot, regression_plot)`` where ``table``
        is the selected columns rounded to 2 decimals.
    """
    corr_data = compute_correlation_data(spr_data_with_scores, score_cols)
    corr_ranking_plot = plot_correlation_ranking(corr_data, correlation_type)

    # Build a fresh list so the caller's main_cols is never mutated.
    cols_to_show = [*main_cols, *score_cols]

    corr_plot = make_regression_plot(spr_data_with_scores, score_cols[0], use_log=False)

    return spr_data_with_scores[cols_to_show].round(2), corr_ranking_plot, corr_plot
|
105 |
|
106 |
+
def make_regression_plot(spr_data_with_scores: pd.DataFrame, score: str, use_log: bool) -> go.Figure:
|
107 |
+
"""Select the regression plot to display."""
|
108 |
+
# corr_plot is a scatter plot of the regression between the binding affinity and each of the scores
|
109 |
scatter = go.Scatter(
|
110 |
x=spr_data_with_scores["KD (nM)"],
|
111 |
y=spr_data_with_scores[score],
|
|
|
128 |
),
|
129 |
xaxis_type="log" if use_log else "linear" # Set x-axis to logarithmic scale
|
130 |
)
|
131 |
+
# compute the regression line
|
132 |
corr_line = np.polyfit(spr_data_with_scores["KD (nM)"], spr_data_with_scores[score], 1)
|
133 |
corr_line_x = np.linspace(min(spr_data_with_scores["KD (nM)"]), max(spr_data_with_scores["KD (nM)"]), 100)
|
134 |
corr_line_y = corr_line[0] * corr_line_x + corr_line[1]
|
135 |
+
# add the regression line to the plot
|
136 |
corr_plot.add_trace(go.Scatter(
|
137 |
x=corr_line_x,
|
138 |
y=corr_line_y,
|