Simon Clematide commited on
Commit
fa342d2
·
1 Parent(s): 4d23e7a

Add CLI script for processing JSONL files and generating binary predictions with optional Excel output

Browse files
Files changed (1) hide show
  1. sdg_predict/cli_conversion.py +86 -0
sdg_predict/cli_conversion.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ import pandas as pd
4
+ import logging
5
+
6
+
7
+ def binary_from_softmax(prediction, cap_class0=0.5):
8
+ """
9
+ Given a softmax-style prediction list, computes binary scores
10
+ for all non-class-0 labels, contrasted against (possibly capped) class-0 score.
11
+
12
+ Args:
13
+ prediction: list of {"label": str, "score": float}
14
+ cap_class0: float, maximum score allowed for label "0"
15
+
16
+ Returns:
17
+ dict of {label: binary_score}
18
+ """
19
+ score_0 = next((x["score"] for x in prediction if x["label"] == "0"), 0.0)
20
+ score_0 = min(score_0, cap_class0)
21
+
22
+ binary_predictions = {}
23
+ for entry in prediction:
24
+ label = entry["label"]
25
+ if label == "0":
26
+ continue
27
+ score = entry["score"]
28
+ binary_score = score / (score + score_0) if (score + score_0) > 0 else 0.0
29
+ binary_predictions[label] = round(binary_score, 3)
30
+
31
+ return binary_predictions
32
+
33
+
34
+ def process_jsonl(input_file, output_file, cap_class0, excel_file=None):
35
+ transformed_data = []
36
+ with open(input_file, "r") as infile, open(output_file, "w") as outfile:
37
+ for line in infile:
38
+ entry = json.loads(line)
39
+ prediction = entry.get("prediction", [])
40
+ entry["binary_predictions"] = binary_from_softmax(prediction, cap_class0)
41
+ outfile.write(json.dumps(entry, ensure_ascii=False) + "\n")
42
+
43
+ # Prepare data for Excel output
44
+ transformed_row = {
45
+ "publication_zora_id": entry.get("id"),
46
+ **{
47
+ f"dvdblk_sdg{sdg}": entry["binary_predictions"].get(str(sdg), 0)
48
+ for sdg in range(1, 18)
49
+ },
50
+ }
51
+ transformed_data.append(transformed_row)
52
+
53
+ if excel_file:
54
+ if not excel_file.endswith(".xlsx"):
55
+ raise ValueError("Excel file must have the .xlsx extension")
56
+ logging.info("Writing Excel output to %s", excel_file)
57
+ df_transformed = pd.DataFrame(transformed_data)
58
+ df_transformed.to_excel(excel_file, index=False)
59
+ logging.info("Excel output written to %s", excel_file)
60
+
61
+
62
+ def main():
63
+ parser = argparse.ArgumentParser(
64
+ description="Process JSONL file and compute binary predictions."
65
+ )
66
+ parser.add_argument("input_file", type=str, help="Path to the input JSONL file.")
67
+ parser.add_argument("output_file", type=str, help="Path to the output JSONL file.")
68
+ parser.add_argument(
69
+ "--cap_class0",
70
+ type=float,
71
+ default=0.5,
72
+ help="Maximum score allowed for class 0.",
73
+ )
74
+ parser.add_argument(
75
+ "--excel",
76
+ type=str,
77
+ help="Path to the Excel file for binary predictions (optional).",
78
+ )
79
+
80
+ args = parser.parse_args()
81
+
82
+ process_jsonl(args.input_file, args.output_file, args.cap_class0, args.excel)
83
+
84
+
85
+ if __name__ == "__main__":
86
+ main()