loodvanniekerkginkgo committed
Commit d6a0c44 · Parent: 393870b

Added new validation for very high Spearman correlations

Files changed (6):
  1. about.py +2 -1
  2. app.py +20 -5
  3. constants.py +9 -9
  4. evaluation.py +152 -0
  5. submit.py +2 -0
  6. validation.py +16 -3
about.py CHANGED
@@ -155,7 +155,8 @@ We may release private test set results at intermediate points during the compet
 ## Cross-validation
 
 For the cross-validation metrics (if training only on the GDPa1 dataset), use the `"hierarchical_cluster_IgG_isotype_stratified_fold"` column to split the dataset into folds and make predictions for each of the folds.
-Submit a CSV file in the same format but also containing the `"hierarchical_cluster_IgG_isotype_stratified_fold"` column.
+Submit a CSV file in the same format but also containing the `"hierarchical_cluster_IgG_isotype_stratified_fold"` column.
+We will be releasing a tutorial on cross-validation shortly.
 
 Submissions close on **1 November 2025**.
 """
app.py CHANGED
@@ -170,8 +170,8 @@ with gr.Blocks(theme=gr.themes.Default(text_size=sizes.text_lg)) as demo:
 
     with gr.TabItem(SUBMIT_TAB_NAME, elem_id="boundary-benchmark-tab-table"):
         gr.Markdown(SUBMIT_INTRUCTIONS)
-        submission_type_state = gr.State(value="GDPa1")
-        download_file_state = gr.State(value=EXAMPLE_FILE_DICT["GDPa1"])
+        submission_type_state = gr.State(value="GDPa1_cross_validation")
+        download_file_state = gr.State(value=EXAMPLE_FILE_DICT["GDPa1_cross_validation"])
 
         with gr.Row():
             with gr.Column():
@@ -204,16 +204,31 @@ with gr.Blocks(theme=gr.themes.Default(text_size=sizes.text_lg)) as demo:
                     placeholder="Enter your registration code",
                     info="If you did not receive a registration code, please sign up on the <a href='https://datapoints.ginkgo.bio/ai-competitions/2025-abdev-competition'>Competition Registration page</a> or email <a href='mailto:[email protected]'>[email protected]</a>.",
                 )
+
+                # Extra validation / warning
+                # Add the conditional warning checkbox
+                high_corr_warning = gr.Markdown(
+                    value="",
+                    visible=False,
+                    elem_classes=["warning-box"]
+                )
+                high_corr_checkbox = gr.Checkbox(
+                    label="I understand this may be overfitting",
+                    value=False,
+                    visible=False,
+                    info="This checkbox will appear if your submission shows suspiciously high correlations (>0.9).",
+                )
+
             with gr.Column():
                 submission_type_dropdown = gr.Dropdown(
                     choices=["GDPa1", "GDPa1_cross_validation", "Heldout Test Set"],
-                    value="GDPa1",
+                    value="GDPa1_cross_validation",
                     label="Submission Type",
                     info=f"Choose the dataset corresponding to the track you're participating in. See the '{ABOUT_TAB_NAME}' tab for details.",
                 )
                 download_button = gr.DownloadButton(
                     label="📥 Download example submission CSV for GDPa1",
-                    value=EXAMPLE_FILE_DICT["GDPa1"],
+                    value=EXAMPLE_FILE_DICT["GDPa1_cross_validation"],
                     variant="secondary",
                 )
                 submission_file = gr.File(label="Submission CSV")
@@ -291,4 +306,4 @@ with gr.Blocks(theme=gr.themes.Default(text_size=sizes.text_lg)) as demo:
     )
 
 if __name__ == "__main__":
-    demo.launch(ssr_mode=False)
+    demo.launch(ssr_mode=False, share=True)
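The hunk above adds the warning components, but their event wiring falls outside this diff. One plausible way to toggle them from a validation callback (a hypothetical handler, not part of this commit) would be:

```python
import gradio as gr

# Hypothetical wiring (not in this commit): show the warning and checkbox
# only when the submission's best Spearman on the public set exceeds 0.9.
def toggle_high_corr_warning(max_spearman: float):
    flagged = max_spearman > 0.9
    message = (
        "⚠️ Correlations above 0.9 on the public set may indicate overfitting."
        if flagged
        else ""
    )
    # Returns updates for (high_corr_warning, high_corr_checkbox)
    return gr.update(value=message, visible=flagged), gr.update(visible=flagged)
```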
constants.py CHANGED
@@ -28,6 +28,13 @@ ASSAY_EMOJIS = {
     "Tm2": "🌡️",
     "Titer": "🧪",
 }
+ASSAY_HIGHER_IS_BETTER = {
+    "HIC": False,
+    "Tm2": True,
+    "Titer": True,
+    "PR_CHO": False,
+    "AC-SINS_pH7.4": False,
+}
 # Tabs with emojis
 ABOUT_TAB_NAME = "📖 About / Rules"
 FAQ_TAB_NAME = "❓ FAQs"
@@ -50,15 +57,8 @@ EXAMPLE_FILE_DICT = {
     "GDPa1_cross_validation": "data/example-predictions-cv.csv",
     "Heldout Test Set": "data/example-predictions-heldout.csv",
 }
-ANTIBODY_NAMES_DICT = {
-    "GDPa1": pd.read_csv(EXAMPLE_FILE_DICT["GDPa1"])["antibody_name"].tolist(),
-    "GDPa1_cross_validation": pd.read_csv(EXAMPLE_FILE_DICT["GDPa1_cross_validation"])[
-        "antibody_name"
-    ].tolist(),
-    "Heldout Test Set": pd.read_csv(EXAMPLE_FILE_DICT["Heldout Test Set"])[
-        "antibody_name"
-    ].tolist(),
-}
+# GDPa1 dataset
+GDPa1_path = "hf://datasets/ginkgo-datapoints/GDPa1/GDPa1_v1.2_20250814.csv"
 
 # Huggingface API
 TOKEN = os.environ.get("HF_TOKEN")
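The new `GDPa1_path` uses the `hf://` scheme, which pandas resolves through fsspec; this assumes the `huggingface_hub` package is installed alongside pandas (and, for gated or private access, that a token such as the `HF_TOKEN` read above is configured):

```python
import pandas as pd

# hf:// URLs are handled by huggingface_hub's HfFileSystem via fsspec
GDPa1_path = "hf://datasets/ginkgo-datapoints/GDPa1/GDPa1_v1.2_20250814.csv"
df_gdpa1 = pd.read_csv(GDPa1_path)
print(df_gdpa1.shape)
```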
evaluation.py ADDED
@@ -0,0 +1,152 @@
+from collections import defaultdict
+from scipy.stats import spearmanr
+import pandas as pd
+import numpy as np
+
+from constants import ASSAY_LIST, ASSAY_HIGHER_IS_BETTER
+
+
+FOLD_COL = "hierarchical_cluster_IgG_isotype_stratified_fold"
+
+
+def recall_at_k(y_true: np.ndarray, y_pred: np.ndarray, frac: float = 0.1) -> float:
+    """Calculate recall (TP)/(TP+FN) for top fraction of true values.
+
+    A recall of 1 would mean that the top fraction of true values are also the top fraction of predicted values.
+    There is no penalty for ranking the top k differently.
+
+    Args:
+        y_true (np.ndarray): true values with shape (num_data,)
+        y_pred (np.ndarray): predicted values with shape (num_data,)
+        frac (float, optional): fraction of data points to consider as the top. Defaults to 0.1.
+
+    Returns:
+        float: recall at top k of data
+    """
+    top_k = int(len(y_true) * frac)
+    y_true, y_pred = np.array(y_true).flatten(), np.array(y_pred).flatten()
+    true_top_k = np.argsort(y_true)[-1 * top_k :]
+    predicted_top_k = np.argsort(y_pred)[-1 * top_k :]
+
+    return (
+        len(
+            set(list(true_top_k.flatten())).intersection(
+                set(list(predicted_top_k.flatten()))
+            )
+        )
+        / top_k
+    )
+
+
+def get_metrics(
+    predictions_series: pd.Series, target_series: pd.Series, assay_col: str
+) -> dict[str, float]:
+    results_dict = {
+        "spearman": spearmanr(
+            predictions_series, target_series, nan_policy="omit"
+        ).correlation
+    }
+    # Top 10% recall
+    y_true = target_series.values
+    y_pred = predictions_series.values
+    if not ASSAY_HIGHER_IS_BETTER[assay_col]:
+        y_true = -1 * y_true
+        y_pred = -1 * y_pred
+    results_dict["top_10_recall"] = recall_at_k(y_true=y_true, y_pred=y_pred, frac=0.1)
+    return results_dict
+
+
+def get_metrics_cross_validation(
+    predictions_series: pd.Series,
+    target_series: pd.Series,
+    folds_series: pd.Series,
+    assay_col: str,
+) -> dict[str, float]:
+    # Run evaluate in a cross-validation loop
+    results_dict = defaultdict(list)
+    if folds_series.nunique() != 5:
+        raise ValueError(f"Expected 5 folds, got {folds_series.nunique()}")
+    for fold in folds_series.unique():
+        predictions_series_fold = predictions_series[folds_series == fold]
+        target_series_fold = target_series[folds_series == fold]
+        results = get_metrics(predictions_series_fold, target_series_fold, assay_col)
+        # Update the results_dict with the results for this fold
+        for key, value in results.items():
+            results_dict[key].append(value)
+    # Calculate the mean of the results for each key (could also add std dev later)
+    for key, values in results_dict.items():
+        results_dict[key] = np.mean(values)
+    return results_dict
+
+
+def _get_result_for_assay(df_merged, assay_col, dataset_name):
+    """
+    Return a dictionary with the results for a single assay.
+    """
+    if dataset_name == "GDPa1_cross_validation":
+        results = get_metrics_cross_validation(
+            df_merged[assay_col + "_pred"],
+            df_merged[assay_col + "_true"],
+            df_merged[FOLD_COL],
+            assay_col,
+        )
+    elif dataset_name == "GDPa1":
+        results = get_metrics(
+            df_merged[assay_col + "_pred"], df_merged[assay_col + "_true"], assay_col
+        )
+    elif dataset_name == "Heldout Test Set":
+        # Just record these as NaNs for now - they'll appear on the leaderboard and we can handle them on their own
+        results = {"spearman": np.nan, "top_10_recall": np.nan}
+    results["assay"] = assay_col
+    return results
+
+
+def _get_error_result(assay_col, dataset_name, error):
+    """
+    Return a dictionary with the error message instead of metrics.
+    Used when _get_result_for_assay fails.
+    """
+    print(f"Error evaluating {assay_col}: {error}")
+    # Add a failed result record with error information
+    error_result = {
+        "dataset": dataset_name,
+        "assay": assay_col,
+    }
+
+    error_result.update({"spearman": error, "top_10_recall": error})
+    return error_result
+
+
+def evaluate(predictions_df, target_df, dataset_name="GDPa1"):
+    """
+    Evaluates a single model, where the predictions dataframe has columns named by property.
+    eg. my_model.csv has columns antibody_name, HIC, Tm2
+    Lood: Copied from Github repo, which I should move over here
+    """
+    properties_in_preds = [
+        col for col in predictions_df.columns if col in ASSAY_LIST
+    ]
+    df_merged = pd.merge(
+        target_df[["antibody_name", FOLD_COL] + ASSAY_LIST],
+        predictions_df[["antibody_name"] + properties_in_preds],
+        on="antibody_name",
+        how="left",
+        suffixes=("_true", "_pred"),
+    )
+    results_list = []
+    # Process each property one by one for better error handling
+    for assay_col in properties_in_preds:
+        try:
+            results = _get_result_for_assay(
+                df_merged, assay_col, dataset_name
+            )
+            results_list.append(results)
+
+        except Exception as e:
+            error_result = _get_error_result(
+                assay_col, dataset_name, e
+            )
+            results_list.append(error_result)
+
+    results_df = pd.DataFrame(results_list)
+    return results_df
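To make the `recall_at_k` semantics concrete, a self-contained check (illustrative values, run from the Space's repo root so `evaluation` is importable): with ten points and `frac=0.1` only the single highest true value counts, and there is no credit for near misses. `evaluate` itself needs the full GDPa1 target frame (all `ASSAY_LIST` columns plus the fold column), so it is easiest to exercise against the real dataset.

```python
import numpy as np
from evaluation import recall_at_k

y_true = np.arange(10)                              # true values 0..9
y_pred = np.array([0, 1, 2, 3, 4, 5, 6, 7, 9, 8])   # swaps the top two

# top 10% of 10 points = 1 point: the true best (index 9) is only
# ranked second by y_pred, so recall is 0.0
print(recall_at_k(y_true, y_pred, frac=0.1))  # 0.0

# top 20% = 2 points: {8, 9} are the top two in both rankings, so 1.0
print(recall_at_k(y_true, y_pred, frac=0.2))  # 1.0
```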
submit.py CHANGED
@@ -98,6 +98,8 @@ def make_submission(
     if path_obj.suffix.lower() != ".csv":
         raise gr.Error("File must be a CSV file. Please upload a .csv file.")
 
+
+
     upload_submission(
         file_path=path_obj,
         user_state=user_state,
validation.py CHANGED
@@ -7,8 +7,9 @@ from constants import (
     ASSAY_LIST,
     CV_COLUMN,
     EXAMPLE_FILE_DICT,
-    ANTIBODY_NAMES_DICT,
+    GDPa1_path,
 )
+from evaluation import evaluate
 
 
 def validate_username(username: str) -> bool:
@@ -137,6 +138,7 @@
         raise gr.Error(
             f"❌ Fold assignments don't match canonical CV folds: {'; '.join(examples)}"
         )
+
 
 
 def validate_full_dataset_submission(df: pd.DataFrame) -> None:
@@ -202,9 +204,11 @@ def validate_dataframe(df: pd.DataFrame, submission_type: str = "GDPa1") -> None
         raise gr.Error(
             f"❌ CSV should have only one row per antibody. Found {n_duplicates} duplicates."
         )
+
+    example_df = pd.read_csv(EXAMPLE_FILE_DICT[submission_type])
     # All antibody names should be recognizable
     unrecognized_antibodies = set(df["antibody_name"]) - set(
-        ANTIBODY_NAMES_DICT[submission_type]
+        example_df["antibody_name"].tolist()
     )
     if unrecognized_antibodies:
         raise gr.Error(
@@ -212,7 +216,8 @@ def validate_dataframe(df: pd.DataFrame, submission_type: str = "GDPa1") -> None
         )
 
     # All antibody names should be present
-    missing_antibodies = set(ANTIBODY_NAMES_DICT[submission_type]) - set(
+    # Note(Lood): Technically we could check that the antibodies are present just for the property that needs to be predicted
+    missing_antibodies = set(example_df["antibody_name"].tolist()) - set(
        df["antibody_name"]
    )
    if missing_antibodies:
@@ -224,6 +229,14 @@ def validate_dataframe(df: pd.DataFrame, submission_type: str = "GDPa1") -> None
         validate_cv_submission(df, submission_type)
     else:  # full_dataset
         validate_full_dataset_submission(df)
+
+    # Check Spearman correlations on public set
+    df_gdpa1 = pd.read_csv(GDPa1_path)
+    if submission_type in ["GDPa1", "GDPa1_cross_validation"]:
+        results_df = evaluate(predictions_df=df, target_df=df_gdpa1, dataset_name=submission_type)
+        # Check that the Spearman correlations are not too high
+        if results_df["spearman"].max() > 0.9:
+            raise gr.Error(f"❌ Your submission shows abnormally high correlations (>0.9) on the public set. Please check that you're not overfitting on the public set and are using cross-validation if training a new model.\nIf you think this is a mistake, please contact [email protected].", duration=30)
 
 
 def validate_csv_file(file_content: str, submission_type: str = "GDPa1") -> None: