DeId-Small / run_benchmarks.py
#!/usr/bin/env python3
"""
Minimal De-identification Benchmark Runner for HuggingFace Publication
This script evaluates a de-identification model's performance on key metrics:
- PII Detection Rate: How well it identifies personal identifiers
- Completeness: Whether all PII is successfully masked
- Semantic Preservation: How well meaning is preserved
- Latency: Response time performance
- Domain Performance: Results across different text types
"""
import json
import re
import time
import requests
from typing import Dict, List, Tuple, Any
import yaml
from datetime import datetime
import sys
import os
class DeIdBenchmarkRunner:
def __init__(self, config_path: str):
with open(config_path, 'r') as f:
self.config = yaml.safe_load(f)
self.results = {
"metadata": {
"timestamp": datetime.now().isoformat(),
"model": "Minibase-DeId-Small",
"dataset": self.config["datasets"]["benchmark_dataset"]["file_path"],
"sample_size": self.config["datasets"]["benchmark_dataset"]["sample_size"]
},
"metrics": {},
"domain_performance": {},
"examples": []
}
def load_dataset(self) -> List[Dict]:
"""Load and sample the benchmark dataset"""
dataset_path = self.config["datasets"]["benchmark_dataset"]["file_path"]
sample_size = self.config["datasets"]["benchmark_dataset"]["sample_size"]
examples = []
with open(dataset_path, 'r') as f:
for i, line in enumerate(f):
if i >= sample_size:
break
examples.append(json.loads(line.strip()))
print(f"✅ Loaded {len(examples)} examples from {dataset_path}")
return examples
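# Example of one JSONL record this loader expects. Field names are whatever the
# config specifies (see the sketch at the top of the file); the values below
# are made up for illustration:
#   {"instruction": "De-identify the following text.",
#    "input": "John Smith visited on 1985-03-15.",
#    "output": "[NAME_1] visited on [DOB_1]."}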
def categorize_domain(self, text: str) -> str:
"""Categorize text by domain based on keywords"""
text_lower = text.lower()
for domain, info in self.config["metrics"]["domain_performance"].items():
if any(keyword in text_lower for keyword in info["keywords"]):
return domain
return "general"
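# Illustration (domain names and keywords come from the config and are assumed
# here): a text containing "patient" would map to a "medical" entry whose
# keyword list includes "patient"; text matching no keyword list falls back to
# "general".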
def extract_placeholders(self, text: str) -> List[str]:
"""Extract all placeholder tags from text (e.g., [NAME_1], [DOB_1])"""
# Match patterns like [WORD_1], [WORD_NUMBER], etc.
pattern = r'\[([A-Z_]+_\d+)\]'
return re.findall(pattern, text)
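# Illustration of the regex above:
#   extract_placeholders("Call [NAME_1] at [PHONE_1].") -> ["NAME_1", "PHONE_1"]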
def calculate_pii_detection_rate(self, predicted: str, expected: str) -> float:
"""Calculate how many expected PII elements were detected"""
expected_placeholders = set(self.extract_placeholders(expected))
if not expected_placeholders:
return 1.0 # No PII to detect
predicted_placeholders = set(self.extract_placeholders(predicted))
# Calculate overlap
detected = len(expected_placeholders.intersection(predicted_placeholders))
return detected / len(expected_placeholders)
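# Worked example: expected "[NAME_1] born [DOB_1]" vs. predicted "[NAME_1] born 1985"
# -> expected placeholders {NAME_1, DOB_1}, predicted {NAME_1},
#    detection rate = 1 / 2 = 0.5.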
def calculate_completeness(self, predicted: str) -> bool:
"""Check if response appears to have no obvious PII remaining"""
# Simple heuristics for detecting remaining PII
pii_patterns = [
r'\b\d{4}-\d{2}-\d{2}\b', # Dates like 1985-03-15
r'\b\d{1,3}\s+[A-Z][a-z]+\s+(?:St|Street|Ave|Avenue|Rd|Road)\b', # Addresses
r'\(\d{3}\)\s*\d{3}-\d{4}\b', # Phone numbers
r'\b[A-Z][a-z]+\s+[A-Z][a-z]+\b', # Names (simplified)
r'\b\d+@\w+\.\w+\b' # Email addresses
]
# If any PII patterns remain, it's incomplete
for pattern in pii_patterns:
if re.search(pattern, predicted):
return False
return True
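# Example: "Patient [NAME_1] lives at 123 Main Street" still matches the address
# pattern above, so it is flagged as incomplete (returns False). Note that the
# capitalized-word-pair name pattern is a coarse heuristic and can also flag
# benign phrases such as "New York".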
def calculate_semantic_preservation(self, predicted: str, expected: str) -> float:
"""Calculate semantic preservation based on placeholder structure"""
# Simple similarity: compare placeholder types and counts
pred_placeholders = self.extract_placeholders(predicted)
expected_placeholders = self.extract_placeholders(expected)
if not expected_placeholders:
return 1.0
# Count placeholder types
def count_types(placeholders):
types = {}
for ph in placeholders:
# Extract type (e.g., "NAME" from "NAME_1")
ptype = ph.split('_')[0]
types[ptype] = types.get(ptype, 0) + 1
return types
pred_types = count_types(pred_placeholders)
expected_types = count_types(expected_placeholders)
# Calculate similarity based on type distribution
all_types = set(pred_types.keys()) | set(expected_types.keys())
similarity = 0
for ptype in all_types:
pred_count = pred_types.get(ptype, 0)
exp_count = expected_types.get(ptype, 0)
if exp_count > 0:
similarity += min(pred_count, exp_count) / exp_count
return similarity / len(all_types) if all_types else 1.0
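# Worked example: expected placeholders {NAME_1, NAME_2, DOB_1} vs. predicted {NAME_1}
# -> type counts: expected {NAME: 2, DOB: 1}, predicted {NAME: 1}
# -> similarity = (min(1, 2)/2 + min(0, 1)/1) / 2 = (0.5 + 0) / 2 = 0.25.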
def call_model(self, instruction: str, input_text: str) -> Tuple[str, float]:
"""Call the de-identification model and measure latency"""
prompt = f"{instruction}\n\nInput: {input_text}\n\nResponse: "
payload = {
"prompt": prompt,
"max_tokens": self.config["model"]["max_tokens"],
"temperature": self.config["model"]["temperature"]
}
headers = {'Content-Type': 'application/json'}
start_time = time.time()
try:
response = requests.post(
f"{self.config['model']['base_url']}/completion",
json=payload,
headers=headers,
timeout=self.config["model"]["timeout"]
)
latency = (time.time() - start_time) * 1000 # Convert to ms
if response.status_code == 200:
result = response.json()
return result.get('content', ''), latency
else:
return f"Error: Server returned status {response.status_code}", latency
except requests.exceptions.RequestException as e:
latency = (time.time() - start_time) * 1000
return f"Error: {e}", latency
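# Illustrative request/response for the call above. The backend server is not
# part of this repo, so this shape is an assumption based on the fields used:
#   POST {base_url}/completion  {"prompt": "...", "max_tokens": 256, "temperature": 0.1}
#   200 OK -> {"content": "[NAME_1] visited on [DOB_1]."}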
def run_benchmarks(self):
"""Run the complete benchmark suite"""
print("🚀 Starting De-identification Benchmarks...")
print(f"📊 Sample size: {self.config['datasets']['benchmark_dataset']['sample_size']}")
print(f"🎯 Model: {self.results['metadata']['model']}")
print()
examples = self.load_dataset()
# Initialize metrics
total_pii_detection = 0
total_completeness = 0
total_semantic_preservation = 0
total_latency = 0
domain_counts = {}
domain_metrics = {}
successful_requests = 0
for i, example in enumerate(examples):
if i % 10 == 0:
print(f"📈 Progress: {i}/{len(examples)} examples processed")
instruction = example[self.config["datasets"]["benchmark_dataset"]["instruction_field"]]
input_text = example[self.config["datasets"]["benchmark_dataset"]["input_field"]]
expected_output = example[self.config["datasets"]["benchmark_dataset"]["expected_output_field"]]
# Categorize domain
domain = self.categorize_domain(input_text)
domain_counts[domain] = domain_counts.get(domain, 0) + 1
# Call model
predicted_output, latency = self.call_model(instruction, input_text)
if not predicted_output.startswith("Error"):
successful_requests += 1
# Calculate metrics
pii_detection = self.calculate_pii_detection_rate(predicted_output, expected_output)
completeness = self.calculate_completeness(predicted_output)
semantic_preservation = self.calculate_semantic_preservation(predicted_output, expected_output)
# Update totals
total_pii_detection += pii_detection
total_completeness += completeness
total_semantic_preservation += semantic_preservation
total_latency += latency
# Update domain metrics
if domain not in domain_metrics:
domain_metrics[domain] = {"pii_detection": 0, "completeness": 0, "semantic": 0, "count": 0}
domain_metrics[domain]["pii_detection"] += pii_detection
domain_metrics[domain]["completeness"] += completeness
domain_metrics[domain]["semantic"] += semantic_preservation
domain_metrics[domain]["count"] += 1
# Store example if requested
if len(self.results["examples"]) < self.config["output"]["max_examples"]:
self.results["examples"].append({
"input": input_text,
"expected": expected_output,
"predicted": predicted_output,
"domain": domain,
"metrics": {
"pii_detection": pii_detection,
"completeness": completeness,
"semantic_preservation": semantic_preservation,
"latency_ms": latency
}
})
# Calculate final metrics
if successful_requests > 0:
self.results["metrics"] = {
"pii_detection_rate": total_pii_detection / successful_requests,
"completeness_score": total_completeness / successful_requests,
"semantic_preservation": total_semantic_preservation / successful_requests,
"average_latency_ms": total_latency / successful_requests,
"successful_requests": successful_requests,
"total_requests": len(examples)
}
# Calculate domain performance
for domain, metrics in domain_metrics.items():
count = metrics["count"]
self.results["domain_performance"][domain] = {
"sample_count": count,
"pii_detection_rate": metrics["pii_detection"] / count,
"completeness_score": metrics["completeness"] / count,
"semantic_preservation": metrics["semantic"] / count
}
self.save_results()
def save_results(self):
"""Save benchmark results to files"""
# Save detailed JSON results
with open(self.config["output"]["detailed_results_file"], 'w') as f:
json.dump(self.results, f, indent=2)
# Save human-readable summary
summary = self.generate_summary()
with open(self.config["output"]["results_file"], 'w') as f:
f.write(summary)
print("\n✅ Benchmark complete!")
print(f"📄 Detailed results saved to: {self.config['output']['detailed_results_file']}")
print(f"📊 Summary saved to: {self.config['output']['results_file']}")
def generate_summary(self) -> str:
"""Generate a human-readable benchmark summary"""
m = self.results["metrics"]
summary = f"""# De-identification Benchmark Results
**Model:** {self.results['metadata']['model']}
**Dataset:** {self.results['metadata']['dataset']}
**Sample Size:** {self.results['metadata']['sample_size']}
**Date:** {self.results['metadata']['timestamp']}
## Overall Performance
| Metric | Score | Description |
|--------|-------|-------------|
| PII Detection Rate | {m.get('pii_detection_rate', 0):.3f} | How well personal identifiers are detected |
| Completeness Score | {m.get('completeness_score', 0):.3f} | Percentage of texts fully de-identified |
| Semantic Preservation | {m.get('semantic_preservation', 0):.3f} | How well meaning is preserved |
| Average Latency | {m.get('average_latency_ms', 0):.1f}ms | Response time performance |
## Domain Performance
"""
for domain, metrics in self.results["domain_performance"].items():
summary += f"### {domain.title()} Domain ({metrics['sample_count']} samples)\n"
summary += f"- PII Detection: {metrics['pii_detection_rate']:.3f}\n"
summary += f"- Completeness: {metrics['completeness_score']:.3f}\n"
summary += f"- Semantic Preservation: {metrics['semantic_preservation']:.3f}\n\n"
if self.config["output"]["include_examples"] and self.results["examples"]:
summary += "## Example Results\n\n"
for i, example in enumerate(self.results["examples"][:3]): # Show first 3 examples
summary += f"### Example {i+1} ({example['domain']} domain)\n"
summary += f"**Input:** {example['input'][:100]}...\n"
summary += f"**Expected:** {example['expected'][:100]}...\n"
summary += f"**Predicted:** {example['predicted'][:100]}...\n"
summary += f"**PII Detection:** {example['metrics']['pii_detection']:.3f}\n\n"
return summary
def main():
if len(sys.argv) != 2:
print("Usage: python run_benchmarks.py <config_file>")
sys.exit(1)
config_path = sys.argv[1]
if not os.path.exists(config_path):
print(f"Error: Config file {config_path} not found")
sys.exit(1)
runner = DeIdBenchmarkRunner(config_path)
runner.run_benchmarks()
if __name__ == "__main__":
main()
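# Example invocation (the config filename here is illustrative):
#   python run_benchmarks.py benchmark_config.yaml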