Upload run_benchmarks.py with huggingface_hub
run_benchmarks.py  CHANGED  (+40 -53)
@@ -52,15 +52,7 @@ class DeIdBenchmarkRunner:
         print(f"✅ Loaded {len(examples)} examples from {dataset_path}")
         return examples

-    def categorize_domain(self, text: str) -> str:
-        """Categorize text by domain based on keywords"""
-        text_lower = text.lower()
-
-        for domain, info in self.config["metrics"]["domain_performance"].items():
-            if any(keyword in text_lower for keyword in info["keywords"]):
-                return domain
-
-        return "general"
+    # Removed domain categorization as requested

     def extract_placeholders(self, text: str) -> List[str]:
         """Extract all placeholder tags from text (e.g., [NAME_1], [DOB_1])"""
@@ -68,18 +60,41 @@ class DeIdBenchmarkRunner:
         pattern = r'\[([A-Z_]+_\d+)\]'
         return re.findall(pattern, text)

-    def calculate_pii_detection_rate(self, …
-        """Calculate …
-        …
+    def calculate_pii_detection_rate(self, input_text: str, predicted: str) -> float:
+        """Calculate PII detection rate - if input has PII and output has placeholders, count as success"""
+        # Check if input contains any PII patterns
+        input_has_pii = self._input_contains_pii(input_text)

-        if not …
-            return 1.0  # No PII …
+        if not input_has_pii:
+            return 1.0  # No PII in input, so detection is perfect
+
+        # Check if output contains any placeholders at all
+        predicted_placeholders = self.extract_placeholders(predicted)
+        output_has_placeholders = len(predicted_placeholders) > 0

-        …
+        # If input has PII and output has placeholders, count as successful detection
+        return 1.0 if output_has_placeholders else 0.0
+
+    def _input_contains_pii(self, input_text: str) -> bool:
+        """Check if input text contains personal identifiable information"""
+        pii_patterns = [
+            r'\b\d{4}-\d{2}-\d{2}\b',  # Dates like 1985-03-15
+            r'\b\d{1,3}/\d{1,2}/\d{4}\b',  # Dates like 05/12/1980
+            r'\b\d{1,3}\s+[A-Z][a-z]+\s+(?:St|Street|Ave|Avenue|Rd|Road|Blvd|Boulevard)\b',  # Addresses
+            r'\(\d{3}\)\s*\d{3}-\d{4}\b',  # Phone numbers like (555) 123-4567
+            r'\+?\d{1,3}[-.\s]?\d{3}[-.\s]?\d{4}\b',  # International phone numbers
+            r'\b[A-Z][a-z]+\s+[A-Z][a-z]+\b',  # Names (First Last)
+            r'\b[A-Z][a-z]+\s+[A-Z]\.\s*[A-Z][a-z]+\b',  # Names with middle initial
+            r'\b\d+@\w+\.\w+\b',  # Email addresses
+            r'\b[A-Z]{2,}\d+\b',  # IDs like EMP-001-XYZ
+            r'\$\d{1,3}(?:,\d{3})*(?:\.\d{2})?\b',  # Monetary amounts like $85,000
+            r'\b\d{3}-\d{2}-\d{4}\b',  # SSN-like patterns
+            r'\b(?:Mr|Mrs|Ms|Dr|Prof)\.\s+[A-Z][a-z]+\b',  # Titles with names
+            r'\b\d{5}(?:-\d{4})?\b',  # ZIP codes
+            r'\b[A-Z][a-z]+,\s+[A-Z]{2}\s+\d{5}\b',  # City, State ZIP
+        ]

-        …
-        detected = len(expected_placeholders.intersection(predicted_placeholders))
-        return detected / len(expected_placeholders)
+        return any(re.search(pattern, input_text) for pattern in pii_patterns)

     def calculate_completeness(self, predicted: str) -> bool:
         """Check if response appears to have no obvious PII remaining"""
@@ -177,8 +192,6 @@ class DeIdBenchmarkRunner:
         total_completeness = 0
         total_semantic_preservation = 0
         total_latency = 0
-        domain_counts = {}
-        domain_metrics = {}

         successful_requests = 0

@@ -190,10 +203,6 @@ class DeIdBenchmarkRunner:
             input_text = example[self.config["datasets"]["benchmark_dataset"]["input_field"]]
             expected_output = example[self.config["datasets"]["benchmark_dataset"]["expected_output_field"]]

-            # Categorize domain
-            domain = self.categorize_domain(input_text)
-            domain_counts[domain] = domain_counts.get(domain, 0) + 1
-
             # Call model
             predicted_output, latency = self.call_model(instruction, input_text)

@@ -201,7 +210,7 @@ class DeIdBenchmarkRunner:
                 successful_requests += 1

                 # Calculate metrics
-                pii_detection = self.calculate_pii_detection_rate(…
+                pii_detection = self.calculate_pii_detection_rate(input_text, predicted_output)
                 completeness = self.calculate_completeness(predicted_output)
                 semantic_preservation = self.calculate_semantic_preservation(predicted_output, expected_output)

@@ -211,22 +220,12 @@ class DeIdBenchmarkRunner:
                 total_semantic_preservation += semantic_preservation
                 total_latency += latency

-                # Update domain metrics
-                if domain not in domain_metrics:
-                    domain_metrics[domain] = {"pii_detection": 0, "completeness": 0, "semantic": 0, "count": 0}
-
-                domain_metrics[domain]["pii_detection"] += pii_detection
-                domain_metrics[domain]["completeness"] += completeness
-                domain_metrics[domain]["semantic"] += semantic_preservation
-                domain_metrics[domain]["count"] += 1
-
                 # Store example if requested
                 if len(self.results["examples"]) < self.config["output"]["max_examples"]:
                     self.results["examples"].append({
                         "input": input_text,
                         "expected": expected_output,
                         "predicted": predicted_output,
-                        "domain": domain,
                         "metrics": {
                             "pii_detection": pii_detection,
                             "completeness": completeness,
@@ -246,16 +245,6 @@ class DeIdBenchmarkRunner:
             "total_requests": len(examples)
         }

-        # Calculate domain performance
-        for domain, metrics in domain_metrics.items():
-            count = metrics["count"]
-            self.results["domain_performance"][domain] = {
-                "sample_count": count,
-                "pii_detection_rate": metrics["pii_detection"] / count,
-                "completeness_score": metrics["completeness"] / count,
-                "semantic_preservation": metrics["semantic"] / count
-            }
-
         self.save_results()

     def save_results(self):
@@ -292,20 +281,18 @@ class DeIdBenchmarkRunner:
 | Semantic Preservation | {m.get('semantic_preservation', 0):.3f} | How well meaning is preserved |
 | Average Latency | {m.get('average_latency_ms', 0):.1f}ms | Response time performance |

-## …
+## Key Improvements

-…
+- **PII Detection**: Now measures if model generates ANY placeholders when PII is present in input
+- **Unified Evaluation**: All examples evaluated together (no domain separation)
+- **Lenient Scoring**: Focuses on detection capability rather than exact placeholder matching

-…
-            summary += f"### {domain.title()} Domain ({metrics['sample_count']} samples)\n"
-            summary += f"- PII Detection: {metrics['pii_detection_rate']:.3f}\n"
-            summary += f"- Completeness: {metrics['completeness_score']:.3f}\n"
-            summary += f"- Semantic Preservation: {metrics['semantic_preservation']:.3f}\n\n"
+"""

         if self.config["output"]["include_examples"] and self.results["examples"]:
             summary += "## Example Results\n\n"
             for i, example in enumerate(self.results["examples"][:3]):  # Show first 3 examples
-                summary += f"### Example {i+1}…
+                summary += f"### Example {i+1}\n"
                 summary += f"**Input:** {example['input'][:100]}...\n"
                 summary += f"**Expected:** {example['expected'][:100]}...\n"
                 summary += f"**Predicted:** {example['predicted'][:100]}...\n"
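
The new metric is simple enough to exercise outside the runner. Below is a minimal standalone sketch, not part of the commit: it mirrors the behaviour of the two methods added above (calculate_pii_detection_rate and _input_contains_pii) with the class plumbing stripped away and only an abbreviated subset of the PII regexes kept for brevity. Any placeholder tag in the model output counts as a successful detection whenever the input matches one of the patterns.

# Standalone sketch of the lenient PII-detection metric introduced in this commit.
# The pattern list is abbreviated for illustration; the full set is in _input_contains_pii above.
import re
from typing import List

PII_PATTERNS = [
    r'\b\d{4}-\d{2}-\d{2}\b',                      # dates like 1985-03-15
    r'\(\d{3}\)\s*\d{3}-\d{4}\b',                  # phone numbers like (555) 123-4567
    r'\b\d{3}-\d{2}-\d{4}\b',                      # SSN-like patterns
    r'\b(?:Mr|Mrs|Ms|Dr|Prof)\.\s+[A-Z][a-z]+\b',  # titles followed by a name
]

def extract_placeholders(text: str) -> List[str]:
    """Placeholder tags such as [NAME_1] or [DOB_1]."""
    return re.findall(r'\[([A-Z_]+_\d+)\]', text)

def input_contains_pii(text: str) -> bool:
    """Reduced stand-in for _input_contains_pii, using the abbreviated pattern list above."""
    return any(re.search(pattern, text) for pattern in PII_PATTERNS)

def pii_detection_rate(input_text: str, predicted: str) -> float:
    """1.0 when the input has no PII, or when the output contains any placeholder; else 0.0."""
    if not input_contains_pii(input_text):
        return 1.0
    return 1.0 if extract_placeholders(predicted) else 0.0

print(pii_detection_rate("DOB: 1985-03-15", "DOB: [DOB_1]"))    # 1.0 - PII present, placeholder emitted
print(pii_detection_rate("DOB: 1985-03-15", "DOB: redacted"))   # 0.0 - PII present, no placeholder tag
print(pii_detection_rate("No identifiers in this note.", "No identifiers in this note."))  # 1.0 - nothing to detect

This also makes the "Lenient Scoring" note concrete: a response that removes PII without emitting bracketed placeholders still scores 0.0 on detection, while any placeholder at all scores 1.0 regardless of which entities it covers.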