Minibase committed · verified
Commit: a363a7b · 1 Parent(s): d16cb83

Upload run_benchmarks.py with huggingface_hub

Files changed (1)
  1. run_benchmarks.py +330 -0
run_benchmarks.py ADDED
@@ -0,0 +1,330 @@
#!/usr/bin/env python3
"""
Minimal De-identification Benchmark Runner for HuggingFace Publication

This script evaluates a de-identification model's performance on key metrics:
- PII Detection Rate: How well it identifies personal identifiers
- Completeness: Whether all PII is successfully masked
- Semantic Preservation: How well meaning is preserved
- Latency: Response time performance
- Domain Performance: Results across different text types
"""

import json
import re
import time
import requests
from typing import Dict, List, Tuple, Any
import yaml
from datetime import datetime
import sys
import os

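# The runner reads its settings from the YAML config passed on the command line.
# The key names below are the ones this script looks up; the values are
# illustrative placeholders only (paths, URL, keywords, and limits are assumptions):
#
#   model:
#     base_url: "http://localhost:8080"        # illustrative endpoint
#     max_tokens: 512
#     temperature: 0.1
#     timeout: 30
#   datasets:
#     benchmark_dataset:
#       file_path: "deid_benchmark.jsonl"      # illustrative path to a JSONL file
#       sample_size: 100
#       instruction_field: "instruction"
#       input_field: "input"
#       expected_output_field: "output"
#   metrics:
#     domain_performance:
#       medical:
#         keywords: ["patient", "diagnosis"]   # illustrative keyword list
#   output:
#     results_file: "benchmark_results.md"
#     detailed_results_file: "benchmark_results.json"
#     max_examples: 10
#     include_examples: true
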

class DeIdBenchmarkRunner:
    def __init__(self, config_path: str):
        with open(config_path, 'r') as f:
            self.config = yaml.safe_load(f)

        self.results = {
            "metadata": {
                "timestamp": datetime.now().isoformat(),
                "model": "Minibase-DeId-Small",
                "dataset": self.config["datasets"]["benchmark_dataset"]["file_path"],
                "sample_size": self.config["datasets"]["benchmark_dataset"]["sample_size"]
            },
            "metrics": {},
            "domain_performance": {},
            "examples": []
        }

    def load_dataset(self) -> List[Dict]:
        """Load and sample the benchmark dataset"""
        dataset_path = self.config["datasets"]["benchmark_dataset"]["file_path"]
        sample_size = self.config["datasets"]["benchmark_dataset"]["sample_size"]

        examples = []
        with open(dataset_path, 'r') as f:
            for i, line in enumerate(f):
                if i >= sample_size:
                    break
                examples.append(json.loads(line.strip()))

        print(f"✅ Loaded {len(examples)} examples from {dataset_path}")
        return examples

    def categorize_domain(self, text: str) -> str:
        """Categorize text by domain based on keywords"""
        text_lower = text.lower()

        for domain, info in self.config["metrics"]["domain_performance"].items():
            if any(keyword in text_lower for keyword in info["keywords"]):
                return domain

        return "general"

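    # categorize_domain() example (keyword list is illustrative): if the config maps
    # "medical" to keywords ["patient", "diagnosis"], then "The patient was discharged."
    # is categorized as "medical"; text matching no configured keyword falls back to "general".
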
    def extract_placeholders(self, text: str) -> List[str]:
        """Extract all placeholder tags from text (e.g., [NAME_1], [DOB_1])"""
        # Match patterns like [WORD_1], [WORD_NUMBER], etc.
        pattern = r'\[([A-Z_]+_\d+)\]'
        return re.findall(pattern, text)

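    # extract_placeholders() example (illustrative input):
    #   extract_placeholders("Call [NAME_1] at [PHONE_1].") -> ["NAME_1", "PHONE_1"]
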
    def calculate_pii_detection_rate(self, predicted: str, expected: str) -> float:
        """Calculate how many expected PII elements were detected"""
        expected_placeholders = set(self.extract_placeholders(expected))

        if not expected_placeholders:
            return 1.0  # No PII to detect

        predicted_placeholders = set(self.extract_placeholders(predicted))

        # Calculate overlap
        detected = len(expected_placeholders.intersection(predicted_placeholders))
        return detected / len(expected_placeholders)

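    # calculate_pii_detection_rate() worked example (illustrative): if the expected output
    # contains [NAME_1] and [DOB_1] but the prediction only contains [NAME_1], the
    # intersection covers 1 of the 2 expected placeholders, so the rate is 0.5.
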
    def calculate_completeness(self, predicted: str) -> bool:
        """Check if response appears to have no obvious PII remaining"""
        # Simple heuristics for detecting remaining PII
        pii_patterns = [
            r'\b\d{4}-\d{2}-\d{2}\b',  # Dates like 1985-03-15
            r'\b\d{1,3}\s+[A-Z][a-z]+\s+(?:St|Street|Ave|Avenue|Rd|Road)\b',  # Addresses
            r'\(\d{3}\)\s*\d{3}-\d{4}\b',  # Phone numbers
            r'\b[A-Z][a-z]+\s+[A-Z][a-z]+\b',  # Names (simplified)
            r'\b\d+@\w+\.\w+\b'  # Email addresses
        ]

        # If any PII patterns remain, it's incomplete
        for pattern in pii_patterns:
            if re.search(pattern, predicted):
                return False

        return True

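    # calculate_completeness() examples (illustrative): "Patient [NAME_1] was seen on [DATE_1]."
    # passes (True), while "Patient John Smith was seen on 1985-03-15." fails (False) because
    # both the name and date patterns match. The simplified name pattern also flags any two
    # adjacent capitalized words (e.g. "New York"), so this check errs on the conservative side.
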
    def calculate_semantic_preservation(self, predicted: str, expected: str) -> float:
        """Calculate semantic preservation based on placeholder structure"""
        # Simple similarity: compare placeholder types and counts
        pred_placeholders = self.extract_placeholders(predicted)
        expected_placeholders = self.extract_placeholders(expected)

        if not expected_placeholders:
            return 1.0

        # Count placeholder types
        def count_types(placeholders):
            types = {}
            for ph in placeholders:
                # Extract type (e.g., "NAME" from "NAME_1"); rsplit keeps multi-word
                # types such as "PHONE_NUMBER" intact
                ptype = ph.rsplit('_', 1)[0]
                types[ptype] = types.get(ptype, 0) + 1
            return types

        pred_types = count_types(pred_placeholders)
        expected_types = count_types(expected_placeholders)

        # Calculate similarity based on type distribution
        all_types = set(pred_types.keys()) | set(expected_types.keys())
        similarity = 0

        for ptype in all_types:
            pred_count = pred_types.get(ptype, 0)
            exp_count = expected_types.get(ptype, 0)
            if exp_count > 0:
                similarity += min(pred_count, exp_count) / exp_count

        return similarity / len(all_types) if all_types else 1.0

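    # calculate_semantic_preservation() worked example (illustrative): expected placeholders
    # [NAME_1], [NAME_2], [DOB_1] give type counts {NAME: 2, DOB: 1}; a prediction with
    # [NAME_1], [DOB_1] gives {NAME: 1, DOB: 1}. Per-type recall is 1/2 for NAME and 1/1
    # for DOB, so the score is (0.5 + 1.0) / 2 = 0.75.
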
    def call_model(self, instruction: str, input_text: str) -> Tuple[str, float]:
        """Call the de-identification model and measure latency"""
        prompt = f"{instruction}\n\nInput: {input_text}\n\nResponse: "

        payload = {
            "prompt": prompt,
            "max_tokens": self.config["model"]["max_tokens"],
            "temperature": self.config["model"]["temperature"]
        }

        headers = {'Content-Type': 'application/json'}

        start_time = time.time()
        try:
            response = requests.post(
                f"{self.config['model']['base_url']}/completion",
                json=payload,
                headers=headers,
                timeout=self.config["model"]["timeout"]
            )
            latency = (time.time() - start_time) * 1000  # Convert to ms

            if response.status_code == 200:
                result = response.json()
                return result.get('content', ''), latency
            else:
                return f"Error: Server returned status {response.status_code}", latency
        except requests.exceptions.RequestException as e:
            latency = (time.time() - start_time) * 1000
            return f"Error: {e}", latency

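    # call_model() assumes an HTTP completion server listening at model.base_url whose JSON
    # reply carries the generated text under a "content" key, e.g. {"content": "Patient [NAME_1] ..."}
    # (illustrative). Any other response shape results in an empty prediction string.
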
    def run_benchmarks(self):
        """Run the complete benchmark suite"""
        print("🚀 Starting De-identification Benchmarks...")
        print(f"📊 Sample size: {self.config['datasets']['benchmark_dataset']['sample_size']}")
        print(f"🎯 Model: {self.results['metadata']['model']}")
        print()

        examples = self.load_dataset()

        # Initialize metrics
        total_pii_detection = 0
        total_completeness = 0
        total_semantic_preservation = 0
        total_latency = 0
        domain_counts = {}
        domain_metrics = {}

        successful_requests = 0

        for i, example in enumerate(examples):
            if i % 10 == 0:
                print(f"📈 Progress: {i}/{len(examples)} examples processed")

            instruction = example[self.config["datasets"]["benchmark_dataset"]["instruction_field"]]
            input_text = example[self.config["datasets"]["benchmark_dataset"]["input_field"]]
            expected_output = example[self.config["datasets"]["benchmark_dataset"]["expected_output_field"]]

            # Categorize domain
            domain = self.categorize_domain(input_text)
            domain_counts[domain] = domain_counts.get(domain, 0) + 1

            # Call model
            predicted_output, latency = self.call_model(instruction, input_text)

            if not predicted_output.startswith("Error"):
                successful_requests += 1

                # Calculate metrics
                pii_detection = self.calculate_pii_detection_rate(predicted_output, expected_output)
                completeness = self.calculate_completeness(predicted_output)
                semantic_preservation = self.calculate_semantic_preservation(predicted_output, expected_output)

                # Update totals
                total_pii_detection += pii_detection
                total_completeness += completeness
                total_semantic_preservation += semantic_preservation
                total_latency += latency

                # Update domain metrics
                if domain not in domain_metrics:
                    domain_metrics[domain] = {"pii_detection": 0, "completeness": 0, "semantic": 0, "count": 0}

                domain_metrics[domain]["pii_detection"] += pii_detection
                domain_metrics[domain]["completeness"] += completeness
                domain_metrics[domain]["semantic"] += semantic_preservation
                domain_metrics[domain]["count"] += 1

                # Store example if requested
                if len(self.results["examples"]) < self.config["output"]["max_examples"]:
                    self.results["examples"].append({
                        "input": input_text,
                        "expected": expected_output,
                        "predicted": predicted_output,
                        "domain": domain,
                        "metrics": {
                            "pii_detection": pii_detection,
                            "completeness": completeness,
                            "semantic_preservation": semantic_preservation,
                            "latency_ms": latency
                        }
                    })

        # Calculate final metrics
        if successful_requests > 0:
            self.results["metrics"] = {
                "pii_detection_rate": total_pii_detection / successful_requests,
                "completeness_score": total_completeness / successful_requests,
                "semantic_preservation": total_semantic_preservation / successful_requests,
                "average_latency_ms": total_latency / successful_requests,
                "successful_requests": successful_requests,
                "total_requests": len(examples)
            }

        # Calculate domain performance
        for domain, metrics in domain_metrics.items():
            count = metrics["count"]
            self.results["domain_performance"][domain] = {
                "sample_count": count,
                "pii_detection_rate": metrics["pii_detection"] / count,
                "completeness_score": metrics["completeness"] / count,
                "semantic_preservation": metrics["semantic"] / count
            }

        self.save_results()

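    # Aggregation note: the overall scores above are simple means over successful requests;
    # failed requests are reflected only in the successful_requests / total_requests counts.
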
    def save_results(self):
        """Save benchmark results to files"""
        # Save detailed JSON results
        with open(self.config["output"]["detailed_results_file"], 'w') as f:
            json.dump(self.results, f, indent=2)

        # Save human-readable summary
        summary = self.generate_summary()
        with open(self.config["output"]["results_file"], 'w') as f:
            f.write(summary)

        print("\n✅ Benchmark complete!")
        print(f"📄 Detailed results saved to: {self.config['output']['detailed_results_file']}")
        print(f"📊 Summary saved to: {self.config['output']['results_file']}")

    def generate_summary(self) -> str:
        """Generate a human-readable benchmark summary"""
        m = self.results["metrics"]

        summary = f"""# De-identification Benchmark Results
**Model:** {self.results['metadata']['model']}
**Dataset:** {self.results['metadata']['dataset']}
**Sample Size:** {self.results['metadata']['sample_size']}
**Date:** {self.results['metadata']['timestamp']}

## Overall Performance

| Metric | Score | Description |
|--------|-------|-------------|
| PII Detection Rate | {m.get('pii_detection_rate', 0):.3f} | How well personal identifiers are detected |
| Completeness Score | {m.get('completeness_score', 0):.3f} | Percentage of texts fully de-identified |
| Semantic Preservation | {m.get('semantic_preservation', 0):.3f} | How well meaning is preserved |
| Average Latency | {m.get('average_latency_ms', 0):.1f}ms | Response time performance |

## Domain Performance

"""

        for domain, metrics in self.results["domain_performance"].items():
            summary += f"### {domain.title()} Domain ({metrics['sample_count']} samples)\n"
            summary += f"- PII Detection: {metrics['pii_detection_rate']:.3f}\n"
            summary += f"- Completeness: {metrics['completeness_score']:.3f}\n"
            summary += f"- Semantic Preservation: {metrics['semantic_preservation']:.3f}\n\n"

        if self.config["output"]["include_examples"] and self.results["examples"]:
            summary += "## Example Results\n\n"
            for i, example in enumerate(self.results["examples"][:3]):  # Show first 3 examples
                summary += f"### Example {i+1} ({example['domain']} domain)\n"
                summary += f"**Input:** {example['input'][:100]}...\n"
                summary += f"**Expected:** {example['expected'][:100]}...\n"
                summary += f"**Predicted:** {example['predicted'][:100]}...\n"
                summary += f"**PII Detection:** {example['metrics']['pii_detection']:.3f}\n\n"

        return summary

def main():
    if len(sys.argv) != 2:
        print("Usage: python run_benchmarks.py <config_file>")
        sys.exit(1)

    config_path = sys.argv[1]
    if not os.path.exists(config_path):
        print(f"Error: Config file {config_path} not found")
        sys.exit(1)

    runner = DeIdBenchmarkRunner(config_path)
    runner.run_benchmarks()


if __name__ == "__main__":
    main()
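
# Example invocation (config filename is illustrative):
#   python run_benchmarks.py benchmark_config.yaml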