import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import torchvision.transforms as transforms
from torchvision.models.resnet import ResNet50_Weights
from PIL import Image
import numpy as np
import os
import requests
import time
import copy
from collections import OrderedDict
from pathlib import Path

# Check for available hardware acceleration
if torch.cuda.is_available():
    device = torch.device("cuda")
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
    device = torch.device("mps")  # Apple Metal Performance Shaders on M-series Macs
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Constants
MODEL_URLS = {
    'resnet50_robust': 'https://huggingface.co/madrylab/robust-imagenet-models/resolve/main/resnet50_l2_eps3.ckpt',
    'resnet50_standard': 'https://huggingface.co/madrylab/robust-imagenet-models/resolve/main/resnet50_l2_eps0.ckpt',
    # Use the '/resolve/' endpoint (not '/blob/') so the raw checkpoint file is downloaded
    'resnet50_robust_face': 'https://huggingface.co/ttoosi/resnet50_robust_face/resolve/main/100_checkpoint.pt'
}
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Define the transforms based on whether normalization is on or off
def get_transform(input_size=224, normalize=False, norm_mean=IMAGENET_MEAN, norm_std=IMAGENET_STD):
    steps = [
        transforms.Resize(input_size),
        transforms.CenterCrop(input_size),
        transforms.ToTensor(),
    ]
    if normalize:
        steps.append(transforms.Normalize(norm_mean, norm_std))
    return transforms.Compose(steps)

# Default transform without normalization
transform = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
])
normalize_transform = transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD)

def extract_middle_layers(model, layer_index):
    """
    Extract a subset of the model up to a specific layer.

    Args:
        model: The neural network model
        layer_index: String 'all' for the full model, or a layer identifier.
            For ResNet/VGG: a child-module name such as 'layer3' (a stringified
            index like '1' works when the model is an nn.Sequential)
            For ViT: strings like 'encoder.layers.encoder_layer_3', or an
            integer block index for models with a 'blocks' attribute

    Returns:
        A modified model that outputs features from the specified layer
    """
    if isinstance(layer_index, str) and layer_index == 'all':
        return model

    # Special case for ViT's encoder layers, with or without a DataParallel wrapper
    if isinstance(layer_index, str) and layer_index.startswith('encoder.layers.encoder_layer_'):
        try:
            target_layer_idx = int(layer_index.split('_')[-1])
            # Create a deep copy of the model to avoid modifying the original
            new_model = copy.deepcopy(model)
            if hasattr(new_model, 'module'):
                # Model is wrapped in DataParallel
                # Create a subset of encoder layers up to the specified index
                encoder_layers = nn.Sequential()
                for i in range(target_layer_idx + 1):
                    layer_name = f"encoder_layer_{i}"
                    if hasattr(new_model.module.encoder.layers, layer_name):
                        encoder_layers.add_module(layer_name,
                                                  getattr(new_model.module.encoder.layers, layer_name))
                # Replace the encoder layers with the truncated version
                new_model.module.encoder.layers = encoder_layers
                # Remove the heads since we stop at the encoder layer
                new_model.module.heads = nn.Identity()
                return new_model
            else:
                # Direct model access (not DataParallel)
                encoder_layers = nn.Sequential()
                for i in range(target_layer_idx + 1):
                    layer_name = f"encoder_layer_{i}"
                    if hasattr(new_model.encoder.layers, layer_name):
                        encoder_layers.add_module(layer_name,
                                                  getattr(new_model.encoder.layers, layer_name))
                new_model.encoder.layers = encoder_layers
                new_model.heads = nn.Identity()
                return new_model
        except (ValueError, IndexError) as e:
            raise ValueError(f"Invalid ViT layer specification: {layer_index}. Error: {e}")

    # Handling for ViT whole blocks
    elif hasattr(model, 'blocks') or (hasattr(model, 'module') and hasattr(model.module, 'blocks')):
        # Create a deep copy to avoid modifying the original
        new_model = copy.deepcopy(model)
        base_new_model = new_model.module if hasattr(new_model, 'module') else new_model
        if isinstance(layer_index, int):
            # Truncate to the desired number of transformer blocks
            base_new_model.blocks = base_new_model.blocks[:layer_index + 1]
            return new_model
        raise ValueError(f"Expected an integer block index for a 'blocks' model, got {layer_index!r}")

    else:
        # Original ResNet/VGG handling: cut at a named child module
        modules = list(model.named_children())
        print(f"DEBUG - extract_middle_layers - Looking for '{layer_index}' in {[name for name, _ in modules]}")
        cutoff_idx = next((i for i, (name, _) in enumerate(modules)
                           if name == str(layer_index)), None)
        if cutoff_idx is not None:
            # Keep modules up to and including the target
            return nn.Sequential(OrderedDict(modules[:cutoff_idx + 1]))
        raise ValueError(f"Module {layer_index} not found in model")

# Get ImageNet labels
def get_imagenet_labels():
    url = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json"
    response = requests.get(url)
    if response.status_code == 200:
        return response.json()
    raise RuntimeError("Failed to fetch ImageNet labels")

# Download model if needed
def download_model(model_type):
    if model_type not in MODEL_URLS or MODEL_URLS[model_type] is None:
        return None  # Use PyTorch's pretrained model

    # Handle the special case for the face model
    if model_type == 'resnet50_robust_face':
        model_path = Path("models/resnet50_robust_face_100_checkpoint.pt")
    else:
        model_path = Path(f"models/{model_type}.pt")

    if not model_path.exists():
        # Ensure the target directory exists before writing
        model_path.parent.mkdir(parents=True, exist_ok=True)
        print(f"Downloading {model_type} model...")
        url = MODEL_URLS[model_type]
        response = requests.get(url, stream=True)
        if response.status_code == 200:
            with open(model_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
            print(f"Model downloaded and saved to {model_path}")
        else:
            raise RuntimeError(f"Failed to download model: {response.status_code}")
    return model_path

class NormalizeByChannelMeanStd(nn.Module):
    def __init__(self, mean, std):
        super().__init__()
        if not isinstance(mean, torch.Tensor):
            mean = torch.tensor(mean)
        if not isinstance(std, torch.Tensor):
            std = torch.tensor(std)
        self.register_buffer("mean", mean)
        self.register_buffer("std", std)

    def forward(self, tensor):
        return self.normalize_fn(tensor, self.mean, self.std)

    def normalize_fn(self, tensor, mean, std):
        """Differentiable version of torchvision.transforms.functional.normalize"""
        # Assumes the color channel is at dim=1
        mean = mean[None, :, None, None]
        std = std[None, :, None, None]
        return tensor.sub(mean).div(std)
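
# Quick sanity sketch: because the normalizer is an nn.Module inside the
# network, the optimization can run in raw [0, 1] pixel space while gradients
# still flow back through normalization to the image.
def _demo_normalizer_grad():
    normalizer = NormalizeByChannelMeanStd(IMAGENET_MEAN, IMAGENET_STD)
    x = torch.rand(1, 3, 224, 224, requires_grad=True)
    normalizer(x).sum().backward()
    print(x.grad is not None)  # expected: True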

class InferStep:
    """One step of eps-constrained gradient inference in pixel space."""

    def __init__(self, orig_image, eps, step_size):
        self.orig_image = orig_image
        self.eps = eps
        self.step_size = step_size

    def project(self, x):
        # Keep x within an eps-ball of the original image and in the valid [0, 1] range
        diff = x - self.orig_image
        diff = torch.clamp(diff, -self.eps, self.eps)
        return torch.clamp(self.orig_image + diff, 0, 1)

    def step(self, x, grad):
        # Return an L2-normalized gradient step (the caller adds it to x)
        l = len(x.shape) - 1
        grad_norm = torch.norm(grad.view(grad.shape[0], -1), dim=1).view(-1, *([1] * l))
        scaled_grad = grad / (grad_norm + 1e-10)
        return scaled_grad * self.step_size
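
# Minimal sketch of a single projected step, using a random image and a
# made-up gradient; the projected update stays within eps of the original.
def _demo_infer_step():
    x0 = torch.rand(1, 3, 224, 224)
    stepper = InferStep(x0, eps=0.5, step_size=1.0)
    fake_grad = torch.randn_like(x0)
    x1 = stepper.project(x0 + stepper.step(x0, fake_grad))
    print((x1 - x0).abs().max().item() <= 0.5)  # expected: True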

def get_iterations_to_show(n_itr):
    """Generate a dynamic list of iterations to visualize based on the total count."""
    if n_itr <= 50:
        return [1, 5, 10, 20, 30, 40, 50, n_itr]
    elif n_itr <= 100:
        return [1, 5, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100, n_itr]
    elif n_itr <= 200:
        return [1, 5, 10, 20, 30, 40, 50, 75, 100, 125, 150, 175, 200, n_itr]
    elif n_itr <= 500:
        return [1, 5, 10, 20, 30, 40, 50, 75, 100, 150, 200, 250, 300, 350, 400, 450, 500, n_itr]
    else:
        # For very long runs, show more evenly distributed points
        return [1, 5, 10, 20, 30, 40, 50, 75, 100, 150, 200, 250, 300, 350, 400, 450, 500,
                int(n_itr * 0.6), int(n_itr * 0.7), int(n_itr * 0.8), int(n_itr * 0.9), n_itr]
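
# Small illustration: the checkpoint list may contain values above n_itr
# (e.g. 40 and 50 for a 35-iteration run); this is harmless because the
# inference loop below only tests membership against the current iteration.
def _demo_iterations_to_show():
    print(get_iterations_to_show(35))  # [1, 5, 10, 20, 30, 40, 50, 35]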

def get_inference_configs(inference_type='IncreaseConfidence', eps=0.5, n_itr=50, step_size=1.0):
    """Generate an inference configuration with customizable parameters.

    Args:
        inference_type (str): Type of inference ('IncreaseConfidence', 'Prior-Guided Drift Diffusion',
            'GradModulation', or 'CompositionalFusion')
        eps (float): Maximum perturbation size
        n_itr (int): Number of iterations
        step_size (float): Step size for each iteration
    """
    # Base configuration common to all inference types
    config = {
        'loss_infer': inference_type,                         # How to guide the optimization
        'n_itr': n_itr,                                       # Number of iterations
        'eps': eps,                                           # Maximum perturbation size
        'step_size': step_size,                               # Step size for each iteration
        'diffusion_noise_ratio': 0.0,                         # No diffusion noise by default
        'initial_inference_noise_ratio': 0.0,                 # No initial noise by default
        'top_layer': 'all',                                   # Use all layers of the model
        'inference_normalization': False,                     # Whether to normalize inputs during inference
        'recognition_normalization': False,                   # Whether to normalize inputs during recognition
        'iterations_to_show': get_iterations_to_show(n_itr),  # Iterations to visualize
        'misc_info': {'keep_grads': False}                    # Additional configuration
    }

    # Customize based on inference type
    if inference_type == 'IncreaseConfidence':
        config['loss_function'] = 'CE'  # Cross-entropy
    elif inference_type == 'Prior-Guided Drift Diffusion':
        config['loss_function'] = 'MSE'  # Mean squared error
        config['initial_inference_noise_ratio'] = 0.05  # Initial noise for diffusion
        config['diffusion_noise_ratio'] = 0.01  # Add noise during diffusion
    elif inference_type == 'GradModulation':
        config['loss_function'] = 'CE'
        config['misc_info']['grad_modulation'] = 0.5  # Gradient modulation strength
    elif inference_type == 'CompositionalFusion':
        config['loss_function'] = 'CE'
        config['misc_info']['positive_classes'] = []  # Classes to maximize
        config['misc_info']['negative_classes'] = []  # Classes to minimize
    return config
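
# Usage sketch: build a diffusion-style config and point it at an intermediate
# layer; 'layer3' is an illustrative choice here, not a default of this module.
def _demo_config():
    cfg = get_inference_configs('Prior-Guided Drift Diffusion', eps=0.5, n_itr=100, step_size=1.0)
    cfg['top_layer'] = 'layer3'  # match features at a mid-level ResNet stage
    print(cfg['loss_function'], cfg['diffusion_noise_ratio'])  # MSE 0.01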

class GenerativeInferenceModel:
    def __init__(self):
        self.models = {}
        self.normalizer = NormalizeByChannelMeanStd(IMAGENET_MEAN, IMAGENET_STD).to(device)
        self.labels = get_imagenet_labels()

    def verify_model_integrity(self, model, model_type):
        """
        Verify model integrity by running a test input through it.
        Returns whether the model passes a basic integrity check.
        """
        try:
            print(f"\n=== Running model integrity check for {model_type} ===")
            # Create a deterministic test input directly on the correct device
            test_input = torch.zeros(1, 3, 224, 224, device=device)
            test_input[0, 0, 100:124, 100:124] = 0.5  # Half-intensity red square (channel 0)

            # Run forward pass
            with torch.no_grad():
                output = model(test_input)

            # Check output shape
            if output.shape != (1, 1000):
                print(f"❌ Unexpected output shape: {output.shape}, expected (1, 1000)")
                return False

            # Get top prediction
            probs = torch.nn.functional.softmax(output, dim=1)
            confidence, prediction = torch.max(probs, 1)

            # Calculate basic statistics on the output
            mean = output.mean().item()
            std = output.std().item()
            min_val = output.min().item()
            max_val = output.max().item()

            print("Model integrity check results:")
            print(f"- Output shape: {output.shape}")
            print(f"- Top prediction: Class {prediction.item()} with {confidence.item() * 100:.2f}% confidence")
            print(f"- Output statistics: mean={mean:.3f}, std={std:.3f}, min={min_val:.3f}, max={max_val:.3f}")

            # Basic sanity checks
            if torch.isnan(output).any():
                print("❌ Model produced NaN outputs")
                return False
            if output.std().item() < 0.1:
                print("⚠️ Low output variance, model may not be discriminative")
            print("✅ Model passes basic integrity check")
            return True
        except Exception as e:
            print(f"❌ Model integrity check failed with error: {e}")
            # Treat the check as advisory rather than failing the load outright
            return True

    def load_model(self, model_type):
        """Load a model from a checkpoint, or fall back to the pretrained model."""
        if model_type in self.models:
            print(f"Using cached {model_type} model")
            return self.models[model_type]

        # Record loading time for performance analysis
        start_time = time.time()
        model_path = download_model(model_type)

        # Create a sequential model with normalizer and ResNet50
        resnet = models.resnet50()
        model = nn.Sequential(
            self.normalizer,  # Normalizer is part of the model sequence
            resnet
        )

        # Load the model checkpoint
        if model_path:
            print(f"Loading {model_type} model from {model_path}...")
            try:
                checkpoint = torch.load(model_path, map_location=device)

                # Print the checkpoint structure for easier debugging
                print("\n=== Analyzing checkpoint structure ===")
                if isinstance(checkpoint, dict):
                    print(f"Checkpoint contains keys: {list(checkpoint.keys())}")
                    # Examine the 'model' structure if it exists
                    if 'model' in checkpoint and isinstance(checkpoint['model'], dict):
                        model_dict = checkpoint['model']
                        # Sample a few keys to understand the structure
                        first_keys = list(model_dict.keys())[:5]
                        print(f"'model' contains keys like: {first_keys}")
                        # Check for common prefixes in the model dict
                        prefixes = set()
                        for key in list(model_dict.keys())[:100]:  # Check first 100 keys
                            parts = key.split('.')
                            if len(parts) > 1:
                                prefixes.add(parts[0])
                        if prefixes:
                            print(f"Common prefixes in model dict: {prefixes}")
                else:
                    print(f"Checkpoint is not a dictionary, but a {type(checkpoint)}")

                # Handle different checkpoint formats
                if 'model' in checkpoint:
                    # Format used by the madrylab robust models
                    state_dict = checkpoint['model']
                    print("Using 'model' key from checkpoint")
                elif 'state_dict' in checkpoint:
                    state_dict = checkpoint['state_dict']
                    print("Using 'state_dict' key from checkpoint")
                else:
                    # Direct state dict
                    state_dict = checkpoint
                    print("Using checkpoint directly as state_dict")

                # Map state-dict keys onto the ResNet's parameter names
                resnet_state_dict = {}
                prefixes_to_try = ['', 'module.', 'model.', 'attacker.model.']
                resnet_keys = set(resnet.state_dict().keys())

                # First check whether keys live under known wrapper paths
                print("\n=== Phase 1: Checking for specific model structures ===")

                # Check for the 'module.model' structure (seen in actual checkpoints)
                module_model_keys = [key for key in state_dict.keys() if key.startswith('module.model.')]
                if module_model_keys:
                    print(f"Found 'module.model' structure with {len(module_model_keys)} parameters")
                    # Extract all parameters from module.model
                    for source_key, value in state_dict.items():
                        if source_key.startswith('module.model.'):
                            target_key = source_key[len('module.model.'):]
                            resnet_state_dict[target_key] = value
                    print(f"Extracted {len(resnet_state_dict)} parameters from module.model")

                    # Check for the 'attacker.model' structure
                    attacker_model_keys = [key for key in state_dict.keys() if key.startswith('attacker.model.')]
                    if attacker_model_keys:
                        print(f"Found 'attacker.model' structure with {len(attacker_model_keys)} parameters")
                        # Extract all parameters from attacker.model
                        for source_key, value in state_dict.items():
                            if source_key.startswith('attacker.model.'):
                                target_key = source_key[len('attacker.model.'):]
                                resnet_state_dict[target_key] = value
                        print(f"Extracted {len(resnet_state_dict)} parameters from attacker.model")

                    # Check whether a plain 'model.' prefix can fill in missing parameters
                    model_keys = [key for key in state_dict.keys() if key.startswith('model.') and not key.startswith('attacker.model.')]
                    if model_keys and len(resnet_state_dict) < len(resnet_keys):
                        print(f"Found additional 'model.' structure with {len(model_keys)} parameters")
                        for source_key, value in state_dict.items():
                            if source_key.startswith('model.'):
                                target_key = source_key[len('model.'):]
                                if target_key in resnet_keys and target_key not in resnet_state_dict:
                                    resnet_state_dict[target_key] = value
                else:
                    # Check for other known structures
                    structure_found = False

                    # Check for a 'model.' prefix
                    model_keys = [key for key in state_dict.keys() if key.startswith('model.')]
                    if model_keys:
                        print(f"Found 'model.' structure with {len(model_keys)} parameters")
                        for source_key, value in state_dict.items():
                            if source_key.startswith('model.'):
                                target_key = source_key[len('model.'):]
                                resnet_state_dict[target_key] = value
                        structure_found = True

                    # Check for ResNet parameters at the top level
                    top_level_resnet_keys = 0
                    for key in resnet_keys:
                        if key in state_dict:
                            top_level_resnet_keys += 1
                    if top_level_resnet_keys > 0:
                        print(f"Found {top_level_resnet_keys} ResNet parameters at top level")
                        for target_key in resnet_keys:
                            if target_key in state_dict:
                                resnet_state_dict[target_key] = state_dict[target_key]
                        structure_found = True

                    # If no structure was recognized, try the prefix mapping approach
                    if not structure_found:
                        print("No standard model structure found, trying prefix mappings...")
                        for target_key in resnet_keys:
                            for prefix in prefixes_to_try:
                                source_key = prefix + target_key
                                if source_key in state_dict:
                                    resnet_state_dict[target_key] = state_dict[source_key]
                                    break

                # If we still can't find enough keys, try removing common prefixes
                if len(resnet_state_dict) < len(resnet_keys):
                    print(f"Found only {len(resnet_state_dict)}/{len(resnet_keys)} parameters, trying prefix removal...")
                    # Track matches found through prefix removal
                    prefix_matches = {prefix: 0 for prefix in ['module.', 'model.', 'attacker.model.', 'attacker.']}
                    layer_matches = {}  # Track matches by layer type

                    # Count parameter keys by layer type for analysis
                    for key in resnet_keys:
                        layer_name = key.split('.')[0] if '.' in key else key
                        if layer_name not in layer_matches:
                            layer_matches[layer_name] = {'total': 0, 'matched': 0}
                        layer_matches[layer_name]['total'] += 1

                    # Try keys with common prefixes removed
                    for source_key, value in state_dict.items():
                        target_key = source_key
                        matched_prefix = None
                        # Try removing various prefixes
                        for prefix in ['module.', 'model.', 'attacker.model.', 'attacker.']:
                            if source_key.startswith(prefix):
                                target_key = source_key[len(prefix):]
                                matched_prefix = prefix
                                break
                        # If the stripped key matches a ResNet key and wasn't already found, add it
                        if target_key in resnet_keys and target_key not in resnet_state_dict:
                            resnet_state_dict[target_key] = value
                            # Update match statistics
                            if matched_prefix:
                                prefix_matches[matched_prefix] += 1
                            # Update layer matches
                            layer_name = target_key.split('.')[0] if '.' in target_key else target_key
                            if layer_name in layer_matches:
                                layer_matches[layer_name]['matched'] += 1

                    # Print detailed prefix-removal statistics
                    print("\n=== Prefix Removal Statistics ===")
                    total_matches = sum(prefix_matches.values())
                    print(f"Total parameters matched through prefix removal: {total_matches}/{len(resnet_keys)} ({(total_matches / len(resnet_keys)) * 100:.1f}%)")

                    # Show matches by prefix
                    print("\nMatches by prefix:")
                    for prefix, count in sorted(prefix_matches.items(), key=lambda x: x[1], reverse=True):
                        if count > 0:
                            print(f"  {prefix}: {count} parameters")

                    # Show matches by layer type
                    print("\nMatches by layer type:")
                    for layer, stats in sorted(layer_matches.items(), key=lambda x: x[1]['total'], reverse=True):
                        match_percent = (stats['matched'] / stats['total']) * 100 if stats['total'] > 0 else 0
                        print(f"  {layer}: {stats['matched']}/{stats['total']} ({match_percent:.1f}%)")

                    # Check the status of specific critical layers (conv1, layer1, etc.)
                    critical_layers = ['conv1', 'bn1', 'layer1', 'layer2', 'layer3', 'layer4', 'fc']
                    print("\nStatus of critical layers:")
                    for layer in critical_layers:
                        if layer in layer_matches:
                            match_percent = (layer_matches[layer]['matched'] / layer_matches[layer]['total']) * 100
                            status = "✅ COMPLETE" if layer_matches[layer]['matched'] == layer_matches[layer]['total'] else "⚠️ INCOMPLETE"
                            print(f"  {layer}: {layer_matches[layer]['matched']}/{layer_matches[layer]['total']} ({match_percent:.1f}%) - {status}")
                        else:
                            print(f"  {layer}: Not found in model")

                # Load the assembled ResNet state dict
                if resnet_state_dict:
                    try:
                        # Use strict=False to allow missing keys
                        result = resnet.load_state_dict(resnet_state_dict, strict=False)
                        missing_keys, unexpected_keys = result

                        # Generate a detailed, well-formatted loading report
                        loading_report = []
                        loading_report.append(f"\n===== MODEL LOADING REPORT: {model_type} =====")
                        loading_report.append(f"Parameter tensors in checkpoint: {len(resnet_state_dict):,}")
                        loading_report.append(f"Parameter tensors in model: {len(resnet.state_dict()):,}")
                        loading_report.append(f"Missing keys: {len(missing_keys):,}")
                        loading_report.append(f"Unexpected keys: {len(unexpected_keys):,}")

                        # Calculate the percentage of parameters loaded
                        loaded_keys = set(resnet_state_dict.keys()) - set(unexpected_keys)
                        loaded_percent = (len(loaded_keys) / len(resnet.state_dict())) * 100

                        # Determine the loading success status
                        if loaded_percent >= 99.5:
                            status = "✅ COMPLETE - All important parameters loaded"
                        elif loaded_percent >= 90:
                            status = "🟡 PARTIAL - Most parameters loaded, should still function"
                        elif loaded_percent >= 50:
                            status = "⚠️ INCOMPLETE - Many parameters missing, may not function properly"
                        else:
                            status = "❌ FAILED - Critical parameters missing, will not function properly"
                        loading_report.append(f"Successfully loaded: {len(loaded_keys):,} parameters ({loaded_percent:.1f}%)")
                        loading_report.append(f"Loading status: {status}")

                        # If loading is severely incomplete, fall back to PyTorch's pretrained model
                        if loaded_percent < 50:
                            loading_report.append("\n⚠️ WARNING: Loading from checkpoint is too incomplete.")
                            loading_report.append("⚠️ Falling back to PyTorch's pretrained model to avoid broken inference.")
                            # Create a new ResNet model with pretrained weights
                            resnet = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
                            model = nn.Sequential(self.normalizer, resnet)
                            loading_report.append("✅ Successfully loaded PyTorch's pretrained ResNet50 model")

                        # Show missing keys by layer type
                        if missing_keys:
                            loading_report.append("\nMissing keys by layer type:")
                            layer_types = {}
                            for key in missing_keys:
                                # Extract the layer type (e.g. 'conv1', 'bn1', 'layer1')
                                parts = key.split('.')
                                if len(parts) > 0:
                                    layer_type = parts[0]
                                    if layer_type not in layer_types:
                                        layer_types[layer_type] = 0
                                    layer_types[layer_type] += 1
                            # Add counts by layer type
                            for layer_type, count in sorted(layer_types.items(), key=lambda x: x[1], reverse=True):
                                loading_report.append(f"  {layer_type}: {count:,} parameters")
                            loading_report.append("\nFirst 10 missing keys:")
                            for i, key in enumerate(sorted(missing_keys)[:10]):
                                loading_report.append(f"  {i + 1}. {key}")

                        # Show unexpected keys, if any
                        if unexpected_keys:
                            loading_report.append("\nFirst 10 unexpected keys:")
                            for i, key in enumerate(sorted(unexpected_keys)[:10]):
                                loading_report.append(f"  {i + 1}. {key}")
                        loading_report.append("========================================")

                        # Print the report
                        report_text = "\n".join(loading_report)
                        print(report_text)

                        # Also save it to a file for reference
                        os.makedirs("logs", exist_ok=True)
                        with open(f"logs/model_loading_{model_type}.log", "w") as f:
                            f.write(report_text)

                        # Look for normalizer parameters as well
                        if any(key.startswith('attacker.normalize.') for key in state_dict.keys()):
                            norm_state_dict = {}
                            for key, value in state_dict.items():
                                if key.startswith('attacker.normalize.'):
                                    norm_key = key[len('attacker.normalize.'):]
                                    norm_state_dict[norm_key] = value
                            if norm_state_dict:
                                try:
                                    self.normalizer.load_state_dict(norm_state_dict, strict=False)
                                    print("Successfully loaded normalizer parameters")
                                except Exception as e:
                                    print(f"Warning: Could not load normalizer parameters: {e}")
                    except Exception as e:
                        print(f"Warning: Error loading ResNet parameters: {e}")
                        # Fall back to using just the ResNet model without the normalizer
                        model = resnet
            except Exception as e:
                print(f"Error loading model checkpoint: {e}")
                # Fall back to PyTorch's pretrained model
                print("Falling back to PyTorch's pretrained model")
                resnet = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
                model = nn.Sequential(self.normalizer, resnet)
        else:
            # No checkpoint available: use PyTorch's pretrained model
            print("No checkpoint available, using PyTorch's pretrained model")
            resnet = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
            model = nn.Sequential(self.normalizer, resnet)

        model = model.to(device)
        model.eval()  # Set to evaluation mode

        # Verify model integrity
        self.verify_model_integrity(model, model_type)

        # Cache the model for future use
        self.models[model_type] = model
        end_time = time.time()
        load_time = end_time - start_time
        print(f"Model {model_type} loaded in {load_time:.2f} seconds")
        return model

    def inference(self, image, model_type, config):
        """Run generative inference on the image."""
        # Time the entire inference process
        inference_start = time.time()

        # Load the model if not already loaded
        model = self.load_model(model_type)

        # Accept a file path or a PIL image; reject pre-transformed tensors
        if isinstance(image, str):
            if os.path.exists(image):
                image = Image.open(image).convert('RGB')
            else:
                raise ValueError(f"Image path does not exist: {image}")
        elif isinstance(image, torch.Tensor):
            raise ValueError(f"Got {type(image)}; expected a file path or PIL image, not an already-transformed tensor")

        # Prepare the image tensor, matching the original code's conditional transform
        load_start = time.time()
        # Accept either a boolean or the legacy 'on'/'off' string convention
        use_norm = config['inference_normalization'] in (True, 'on')
        custom_transform = get_transform(
            input_size=224,
            normalize=use_norm,
            norm_mean=IMAGENET_MEAN,
            norm_std=IMAGENET_STD
        )

        # Special handling for GradModulation, as in the original code
        if config['loss_infer'] == 'GradModulation' and 'misc_info' in config and 'grad_modulation' in config['misc_info']:
            grad_modulation = config['misc_info']['grad_modulation']
            image_tensor = custom_transform(image).unsqueeze(0).to(device)
            # Blend the image with Gaussian noise according to the modulation strength
            image_tensor = image_tensor * (1 - grad_modulation) + grad_modulation * torch.randn_like(image_tensor).to(device)
        else:
            image_tensor = custom_transform(image).unsqueeze(0).to(device)
        image_tensor.requires_grad = True
        print(f"Image loaded and processed in {time.time() - load_start:.2f} seconds")

        # Check the model structure
        is_sequential = isinstance(model, nn.Sequential)

        # Get original predictions
        with torch.no_grad():
            if is_sequential and isinstance(model[0], NormalizeByChannelMeanStd):
                print("Model is sequential with normalization")
                # Get the core model part (index 1 in the Sequential)
                core_model = model[1]
                if use_norm:
                    output_original = model(image_tensor)  # Full model includes the normalizer
                else:
                    output_original = core_model(image_tensor)  # Skip the built-in normalizer
            else:
                print("Model is not sequential with normalization")
                # Use manual normalization for non-sequential models
                if use_norm:
                    normalized_tensor = normalize_transform(image_tensor)
                    output_original = model(normalized_tensor)
                else:
                    output_original = model(image_tensor)
                core_model = model

        probs_orig = F.softmax(output_original, dim=1)
        conf_orig, classes_orig = torch.max(probs_orig, 1)

        # Get the least confident classes
        _, least_confident_classes = torch.topk(probs_orig, k=100, largest=False)

        # Initialize the inference step helper
        infer_step = InferStep(image_tensor, config['eps'], config['step_size'])

        # Storage for inference steps; start from a fresh tensor that requires gradients
        x = image_tensor.clone().detach().requires_grad_(True)
        all_steps = [image_tensor[0].detach().cpu()]

        # For Prior-Guided Drift Diffusion, extract the selected layer and build noisy target features
        noisy_features = None
        layer_model = None
        if config['loss_infer'] == 'Prior-Guided Drift Diffusion':
            print(f"Setting up Prior-Guided Drift Diffusion with layer {config['top_layer']} and noise {config['initial_inference_noise_ratio']}...")
            # Extract the model up to the specified layer
            try:
                # Start by finding the actual model to use
                base_model = model
                # Handle a DataParallel wrapper if present
                if hasattr(base_model, 'module'):
                    base_model = base_model.module
                # Log the initial model structure
                print(f"DEBUG - Initial model structure: {type(base_model)}")
                # If we have a Sequential model (likely the normalizer + model structure)
                if isinstance(base_model, nn.Sequential):
                    print(f"DEBUG - Sequential model with {len(list(base_model.children()))} children")
                    # If this is the NormalizeByChannelMeanStd + ResNet pattern
                    if len(list(base_model.children())) >= 2:
                        # The actual ResNet model is the second component (index 1)
                        actual_model = list(base_model.children())[1]
                        print(f"DEBUG - Using ResNet component: {type(actual_model)}")
                        print(f"DEBUG - Available layers: {[name for name, _ in actual_model.named_children()]}")
                        # Extract from the actual ResNet
                        layer_model = extract_middle_layers(actual_model, config['top_layer'])
                    else:
                        # A single-component Sequential
                        layer_model = extract_middle_layers(base_model, config['top_layer'])
                else:
                    # Not Sequential; might be a direct model
                    print(f"DEBUG - Available layers: {[name for name, _ in base_model.named_children()]}")
                    layer_model = extract_middle_layers(base_model, config['top_layer'])
                print(f"Successfully extracted model up to layer: {config['top_layer']}")
            except ValueError as e:
                print(f"Layer extraction failed: {e}. Using full model.")
                layer_model = model

            # Add noise to the image, exactly matching the original code
            added_noise = config['initial_inference_noise_ratio'] * torch.randn_like(image_tensor).to(device)
            noisy_image_tensor = image_tensor + added_noise
            # Compute the noisy target features (kept simple, as in the original code)
            noisy_features = layer_model(noisy_image_tensor)
            print(f"Noisy features computed for Prior-Guided Drift Diffusion target with shape: {noisy_features.shape if hasattr(noisy_features, 'shape') else 'unknown'}")

        # Main inference loop
        print(f"Starting inference loop with {config['n_itr']} iterations for {config['loss_infer']}...")
        loop_start = time.time()
        for i in range(config['n_itr']):
            # Reset gradients
            x.grad = None

            # Forward pass: use layer_model for Prior-Guided Drift Diffusion, the full model otherwise
            if config['loss_infer'] == 'Prior-Guided Drift Diffusion' and layer_model is not None:
                # In the original code, normalization is handled at transform time, not during the forward pass
                output = layer_model(x)
            else:
                # Standard forward pass with the full model
                output = model(x)

            # Calculate loss and gradients based on the inference type
            try:
                if config['loss_infer'] == 'Prior-Guided Drift Diffusion':
                    # Use MSE loss to match the noisy features
                    assert config['loss_function'] == 'MSE', "Prior-Guided Drift Diffusion loss function must be MSE"
                    if noisy_features is not None:
                        loss = F.mse_loss(output, noisy_features)
                        grad = torch.autograd.grad(loss, x)[0]  # No retain_graph, matching the original code
                    else:
                        raise ValueError("Noisy features not computed for Prior-Guided Drift Diffusion")
                else:  # Default 'IncreaseConfidence' approach
                    # Get the least confident classes
                    num_classes = min(10, least_confident_classes.size(1))
                    target_classes = least_confident_classes[0, :num_classes]
                    # Create targets for the least confident classes
                    targets = torch.tensor([idx.item() for idx in target_classes], device=device)
                    # Use a combined loss to increase confidence in those classes
                    loss = 0
                    for target in targets:
                        # Create a one-hot target
                        one_hot = torch.zeros_like(output)
                        one_hot[0, target] = 1
                        loss = loss + F.mse_loss(F.softmax(output, dim=1), one_hot)
                    grad = torch.autograd.grad(loss, x, retain_graph=True)[0]

                if grad is None:
                    print("Warning: Direct gradient calculation failed")
                    # Fall back to a random perturbation
                    random_noise = (torch.rand_like(x) - 0.5) * 2 * config['step_size']
                    x = infer_step.project(x + random_noise)
                else:
                    # Update the image with the gradient, exactly as in the original code
                    adjusted_grad = infer_step.step(x, grad)
                    # Add diffusion noise if specified
                    diffusion_noise = config['diffusion_noise_ratio'] * torch.randn_like(x).to(device)
                    # Apply gradient and noise in one operation before projecting
                    x = infer_step.project(x.clone() + adjusted_grad + diffusion_noise)
            except Exception as e:
                print(f"Error in gradient calculation: {e}")
                # Fall back to a random perturbation, matching the original code
                random_noise = (torch.rand_like(x) - 0.5) * 2 * config['step_size']
                x = infer_step.project(x.clone() + random_noise)

            # Store the step if it is in iterations_to_show
            if i + 1 in config['iterations_to_show'] or i + 1 == config['n_itr']:
                all_steps.append(x[0].detach().cpu())

        # Print some info about the final inference result
        with torch.no_grad():
            if is_sequential and isinstance(model[0], NormalizeByChannelMeanStd):
                if use_norm:
                    final_output = model(x)
                else:
                    final_output = core_model(x)
            else:
                if use_norm:
                    normalized_x = normalize_transform(x)
                    final_output = model(normalized_x)
                else:
                    final_output = model(x)
            final_probs = F.softmax(final_output, dim=1)
            final_conf, final_classes = torch.max(final_probs, 1)

        # Calculate timing information
        loop_time = time.time() - loop_start
        total_time = time.time() - inference_start
        avg_iter_time = loop_time / config['n_itr'] if config['n_itr'] > 0 else 0
        print(f"Original top class: {classes_orig.item()} ({conf_orig.item():.4f})")
        print(f"Final top class: {final_classes.item()} ({final_conf.item():.4f})")
        print(f"Inference loop completed in {loop_time:.2f} seconds ({avg_iter_time:.4f} sec/iteration)")
        print(f"Total inference time: {total_time:.2f} seconds")

        # Return results in a format compatible with both the old and new code
        return {
            'final_image': x[0].detach().cpu(),
            'steps': all_steps,
            'original_class': classes_orig.item(),
            'original_confidence': conf_orig.item(),
            'final_class': final_classes.item(),
            'final_confidence': final_conf.item()
        }

# Utility function to show inference steps
def show_inference_steps(steps, figsize=(15, 10)):
    import matplotlib.pyplot as plt
    n_steps = len(steps)
    fig, axes = plt.subplots(1, n_steps, figsize=figsize)
    if n_steps == 1:
        axes = [axes]  # plt.subplots returns a bare Axes object for a single plot
    for i, step_img in enumerate(steps):
        img = step_img.permute(1, 2, 0).numpy().clip(0, 1)  # guard against float round-off
        axes[i].imshow(img)
        axes[i].set_title(f"Step {i}")
        axes[i].axis('off')
    plt.tight_layout()
    return fig
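
# End-to-end sketch: 'sample.jpg' is a hypothetical local image path, and both
# the checkpoint download and the ImageNet-label fetch require network access;
# any RGB image can be substituted.
if __name__ == "__main__":
    engine = GenerativeInferenceModel()
    cfg = get_inference_configs('IncreaseConfidence', eps=0.5, n_itr=20, step_size=1.0)
    result = engine.inference("sample.jpg", 'resnet50_standard', cfg)
    print(f"Class {result['original_class']} -> class {result['final_class']}")
    fig = show_inference_steps(result['steps'])
    fig.savefig("inference_steps.png")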