# GenerativeInferenceDemo / inference.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import torchvision.transforms as transforms
from torchvision.models.resnet import ResNet50_Weights
from PIL import Image
import numpy as np
import os
import requests
import time
import copy
from collections import OrderedDict
from pathlib import Path
# Check for available hardware acceleration
if torch.cuda.is_available():
device = torch.device("cuda")
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
device = torch.device("mps") # Use Apple Metal Performance Shaders for M-series Macs
else:
device = torch.device("cpu")
print(f"Using device: {device}")
# Constants
MODEL_URLS = {
'resnet50_robust': 'https://huggingface.co/madrylab/robust-imagenet-models/resolve/main/resnet50_l2_eps3.ckpt',
'resnet50_standard': 'https://huggingface.co/madrylab/robust-imagenet-models/resolve/main/resnet50_l2_eps0.ckpt',
    'resnet50_robust_face': 'https://huggingface.co/ttoosi/resnet50_robust_face/resolve/main/100_checkpoint.pt'
}
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
# Define the transforms based on whether normalization is on or off
def get_transform(input_size=224, normalize=False, norm_mean=IMAGENET_MEAN, norm_std=IMAGENET_STD):
if normalize:
return transforms.Compose([
transforms.Resize(input_size),
transforms.CenterCrop(input_size),
transforms.ToTensor(),
transforms.Normalize(norm_mean, norm_std),
])
else:
return transforms.Compose([
transforms.Resize(input_size),
transforms.CenterCrop(input_size),
transforms.ToTensor(),
])
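# A minimal usage sketch for get_transform (the image path is hypothetical): with
# normalize=False the pipeline yields tensors in [0, 1], which matches the models
# below that embed NormalizeByChannelMeanStd as their first module.
#
#     t = get_transform(input_size=224, normalize=False)
#     img = Image.open("example.jpg").convert("RGB")
#     batch = t(img).unsqueeze(0)  # shape (1, 3, 224, 224), values in [0, 1]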
# Default transform without normalization
transform = transforms.Compose([
transforms.Resize(224),
transforms.CenterCrop(224),
transforms.ToTensor(),
])
normalize_transform = transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD)
def extract_middle_layers(model, layer_index):
"""
Extract a subset of the model up to a specific layer.
Args:
model: The neural network model
layer_index: String 'all' for the full model, or a layer identifier (string or int)
For ResNet: integers 0-8 representing specific layers
For ViT: strings like 'encoder.layers.encoder_layer_3'
Returns:
A modified model that outputs features from the specified layer
"""
if isinstance(layer_index, str) and layer_index == 'all':
return model
    # Special case for torchvision ViT encoder layers (handles an optional DataParallel wrapper)
if isinstance(layer_index, str) and layer_index.startswith('encoder.layers.encoder_layer_'):
try:
target_layer_idx = int(layer_index.split('_')[-1])
# Create a deep copy of the model to avoid modifying the original
new_model = copy.deepcopy(model)
# For models wrapped in DataParallel
if hasattr(new_model, 'module'):
# Create a subset of encoder layers up to the specified index
encoder_layers = nn.Sequential()
for i in range(target_layer_idx + 1):
layer_name = f"encoder_layer_{i}"
if hasattr(new_model.module.encoder.layers, layer_name):
encoder_layers.add_module(layer_name,
getattr(new_model.module.encoder.layers, layer_name))
# Replace the encoder layers with our truncated version
new_model.module.encoder.layers = encoder_layers
# Remove the heads since we're stopping at the encoder layer
new_model.module.heads = nn.Identity()
return new_model
else:
# Direct model access (not DataParallel)
encoder_layers = nn.Sequential()
for i in range(target_layer_idx + 1):
layer_name = f"encoder_layer_{i}"
if hasattr(new_model.encoder.layers, layer_name):
encoder_layers.add_module(layer_name,
getattr(new_model.encoder.layers, layer_name))
# Replace the encoder layers with our truncated version
new_model.encoder.layers = encoder_layers
# Remove the heads since we're stopping at the encoder layer
new_model.heads = nn.Identity()
return new_model
except (ValueError, IndexError) as e:
raise ValueError(f"Invalid ViT layer specification: {layer_index}. Error: {e}")
# Handling for ViT whole blocks
    # Handling for ViT models that expose whole transformer blocks via a 'blocks' attribute
    elif hasattr(model, 'blocks') or (hasattr(model, 'module') and hasattr(model.module, 'blocks')):
        # Create a deep copy to avoid modifying the original, unwrapping DataParallel if present
        new_model = copy.deepcopy(model)
        base_new_model = new_model.module if hasattr(new_model, 'module') else new_model
        if isinstance(layer_index, int):
            # Keep only the transformer blocks up to and including the target index
            base_new_model.blocks = base_new_model.blocks[:layer_index + 1]
            return new_model
        else:
            raise ValueError(f"Expected an integer block index for ViT 'blocks', got: {layer_index}")
else:
# Original ResNet/VGG handling
modules = list(model.named_children())
print(f"DEBUG - extract_middle_layers - Looking for '{layer_index}' in {[name for name, _ in modules]}")
cutoff_idx = next((i for i, (name, _) in enumerate(modules)
if name == str(layer_index)), None)
if cutoff_idx is not None:
# Keep modules up to and including the target
new_model = nn.Sequential(OrderedDict(modules[:cutoff_idx+1]))
return new_model
else:
raise ValueError(f"Module {layer_index} not found in model")
# Get ImageNet labels
def get_imagenet_labels():
url = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json"
response = requests.get(url)
if response.status_code == 200:
return response.json()
else:
raise RuntimeError("Failed to fetch ImageNet labels")
# Download model if needed
def download_model(model_type):
if model_type not in MODEL_URLS or MODEL_URLS[model_type] is None:
return None # Use PyTorch's pretrained model
# Handle special case for face model
if model_type == 'resnet50_robust_face':
model_path = Path("models/resnet50_robust_face_100_checkpoint.pt")
else:
model_path = Path(f"models/{model_type}.pt")
    model_path.parent.mkdir(parents=True, exist_ok=True)
    if not model_path.exists():
print(f"Downloading {model_type} model...")
url = MODEL_URLS[model_type]
response = requests.get(url, stream=True)
if response.status_code == 200:
with open(model_path, 'wb') as f:
for chunk in response.iter_content(chunk_size=8192):
f.write(chunk)
print(f"Model downloaded and saved to {model_path}")
else:
raise RuntimeError(f"Failed to download model: {response.status_code}")
return model_path
class NormalizeByChannelMeanStd(nn.Module):
def __init__(self, mean, std):
super(NormalizeByChannelMeanStd, self).__init__()
if not isinstance(mean, torch.Tensor):
mean = torch.tensor(mean)
if not isinstance(std, torch.Tensor):
std = torch.tensor(std)
self.register_buffer("mean", mean)
self.register_buffer("std", std)
def forward(self, tensor):
return self.normalize_fn(tensor, self.mean, self.std)
def normalize_fn(self, tensor, mean, std):
"""Differentiable version of torchvision.functional.normalize"""
# here we assume the color channel is at dim=1
mean = mean[None, :, None, None]
std = std[None, :, None, None]
return tensor.sub(mean).div(std)
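# The reshapes above broadcast the per-channel statistics over an (N, C, H, W)
# batch; a quick sanity check of the equivalent arithmetic:
#
#     norm = NormalizeByChannelMeanStd(IMAGENET_MEAN, IMAGENET_STD)
#     x = torch.rand(2, 3, 224, 224)
#     y = norm(x)  # y[:, c] == (x[:, c] - mean[c]) / std[c] for each channel c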
class InferStep:
def __init__(self, orig_image, eps, step_size):
self.orig_image = orig_image
self.eps = eps
self.step_size = step_size
def project(self, x):
diff = x - self.orig_image
diff = torch.clamp(diff, -self.eps, self.eps)
return torch.clamp(self.orig_image + diff, 0, 1)
def step(self, x, grad):
l = len(x.shape) - 1
grad_norm = torch.norm(grad.view(grad.shape[0], -1), dim=1).view(-1, *([1]*l))
scaled_grad = grad / (grad_norm + 1e-10)
return scaled_grad * self.step_size
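# InferStep implements one guided update: step() rescales the gradient to unit L2
# norm times step_size, and project() clamps the accumulated perturbation to the
# elementwise [-eps, eps] box around the original image, then back to valid [0, 1]
# pixels. A sketch of a single update (names are illustrative):
#
#     infer_step = InferStep(orig_image=x0, eps=0.5, step_size=1.0)
#     delta = infer_step.step(x, grad)   # L2-normalized step direction
#     x = infer_step.project(x + delta)  # stay inside the eps box and [0, 1]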
def get_iterations_to_show(n_itr):
"""Generate a dynamic list of iterations to show based on total iterations."""
    if n_itr <= 50:
        candidates = [1, 5, 10, 20, 30, 40, 50]
    elif n_itr <= 100:
        candidates = [1, 5, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100]
    elif n_itr <= 200:
        candidates = [1, 5, 10, 20, 30, 40, 50, 75, 100, 125, 150, 175, 200]
    elif n_itr <= 500:
        candidates = [1, 5, 10, 20, 30, 40, 50, 75, 100, 150, 200, 250, 300, 350, 400, 450, 500]
    else:
        # For very long runs, add more evenly distributed points
        candidates = [1, 5, 10, 20, 30, 40, 50, 75, 100, 150, 200, 250, 300, 350, 400, 450, 500,
                      int(n_itr * 0.6), int(n_itr * 0.7), int(n_itr * 0.8), int(n_itr * 0.9)]
    # Drop candidates past n_itr, deduplicate, and always include the final iteration
    return sorted(set(v for v in candidates if v <= n_itr) | {n_itr})
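# For example, get_iterations_to_show(30) returns [1, 5, 10, 20, 30]: candidate
# checkpoints past n_itr are dropped and the final iteration is always included.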
def get_inference_configs(inference_type='IncreaseConfidence', eps=0.5, n_itr=50, step_size=1.0):
"""Generate inference configuration with customizable parameters.
Args:
        inference_type (str): Type of inference ('IncreaseConfidence', 'Prior-Guided Drift Diffusion',
            'GradModulation', or 'CompositionalFusion')
eps (float): Maximum perturbation size
n_itr (int): Number of iterations
step_size (float): Step size for each iteration
"""
# Base configuration common to all inference types
config = {
'loss_infer': inference_type, # How to guide the optimization
'n_itr': n_itr, # Number of iterations
'eps': eps, # Maximum perturbation size
'step_size': step_size, # Step size for each iteration
'diffusion_noise_ratio': 0.0, # No diffusion noise
'initial_inference_noise_ratio': 0.0, # No initial noise
'top_layer': 'all', # Use all layers of the model
        'inference_normalization': False,     # Whether to normalize inputs during inference
        'recognition_normalization': False,   # Whether to normalize inputs during recognition
'iterations_to_show': get_iterations_to_show(n_itr), # Dynamic iterations to visualize
'misc_info': {'keep_grads': False} # Additional configuration
}
# Customize based on inference type
if inference_type == 'IncreaseConfidence':
config['loss_function'] = 'CE' # Cross Entropy
elif inference_type == 'Prior-Guided Drift Diffusion':
config['loss_function'] = 'MSE' # Mean Square Error
config['initial_inference_noise_ratio'] = 0.05 # Initial noise for diffusion
config['diffusion_noise_ratio'] = 0.01 # Add noise during diffusion
elif inference_type == 'GradModulation':
config['loss_function'] = 'CE' # Cross Entropy
config['misc_info']['grad_modulation'] = 0.5 # Gradient modulation strength
elif inference_type == 'CompositionalFusion':
config['loss_function'] = 'CE' # Cross Entropy
config['misc_info']['positive_classes'] = [] # Classes to maximize
config['misc_info']['negative_classes'] = [] # Classes to minimize
return config
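# A usage sketch (the layer choice is illustrative): configure a short
# Prior-Guided Drift Diffusion run against an intermediate ResNet stage.
#
#     cfg = get_inference_configs('Prior-Guided Drift Diffusion', eps=0.5, n_itr=50, step_size=1.0)
#     cfg['top_layer'] = 'layer3'  # guide features from an intermediate layer
#     # cfg['loss_function'] == 'MSE', cfg['diffusion_noise_ratio'] == 0.01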
class GenerativeInferenceModel:
def __init__(self):
self.models = {}
self.normalizer = NormalizeByChannelMeanStd(IMAGENET_MEAN, IMAGENET_STD).to(device)
self.labels = get_imagenet_labels()
def verify_model_integrity(self, model, model_type):
"""
Verify model integrity by running a test input through it.
Returns whether the model passes basic integrity check.
"""
try:
print(f"\n=== Running model integrity check for {model_type} ===")
# Create a deterministic test input directly on the correct device
test_input = torch.zeros(1, 3, 224, 224, device=device)
test_input[0, 0, 100:124, 100:124] = 0.5 # Red square
# Run forward pass
with torch.no_grad():
output = model(test_input)
# Check output shape
if output.shape != (1, 1000):
print(f"❌ Unexpected output shape: {output.shape}, expected (1, 1000)")
return False
# Get top prediction
probs = torch.nn.functional.softmax(output, dim=1)
confidence, prediction = torch.max(probs, 1)
# Calculate basic statistics on output
mean = output.mean().item()
std = output.std().item()
min_val = output.min().item()
max_val = output.max().item()
print(f"Model integrity check results:")
print(f"- Output shape: {output.shape}")
print(f"- Top prediction: Class {prediction.item()} with {confidence.item()*100:.2f}% confidence")
print(f"- Output statistics: mean={mean:.3f}, std={std:.3f}, min={min_val:.3f}, max={max_val:.3f}")
# Basic sanity checks
if torch.isnan(output).any():
print("❌ Model produced NaN outputs")
return False
if output.std().item() < 0.1:
print("⚠️ Low output variance, model may not be discriminative")
print("✅ Model passes basic integrity check")
return True
        except Exception as e:
            print(f"❌ Model integrity check failed with error: {e}")
            # Treat check errors as non-fatal so model loading can proceed
            return True
def load_model(self, model_type):
"""Load model from checkpoint or use pretrained model."""
if model_type in self.models:
print(f"Using cached {model_type} model")
return self.models[model_type]
# Record loading time for performance analysis
start_time = time.time()
model_path = download_model(model_type)
# Create a sequential model with normalizer and ResNet50
resnet = models.resnet50()
model = nn.Sequential(
self.normalizer, # Normalizer is part of the model sequence
resnet
)
# Load the model checkpoint
if model_path:
print(f"Loading {model_type} model from {model_path}...")
try:
checkpoint = torch.load(model_path, map_location=device)
# Print checkpoint structure for better understanding
print("\n=== Analyzing checkpoint structure ===")
if isinstance(checkpoint, dict):
print(f"Checkpoint contains keys: {list(checkpoint.keys())}")
# Examine 'model' structure if it exists
if 'model' in checkpoint and isinstance(checkpoint['model'], dict):
model_dict = checkpoint['model']
# Get sample of keys to understand structure
first_keys = list(model_dict.keys())[:5]
print(f"'model' contains keys like: {first_keys}")
# Check for common prefixes in the model dict
prefixes = set()
for key in list(model_dict.keys())[:100]: # Check first 100 keys
parts = key.split('.')
if len(parts) > 1:
prefixes.add(parts[0])
if prefixes:
print(f"Common prefixes in model dict: {prefixes}")
else:
print(f"Checkpoint is not a dictionary, but a {type(checkpoint)}")
# Handle different checkpoint formats
if 'model' in checkpoint:
# Format from madrylab robust models
state_dict = checkpoint['model']
print("Using 'model' key from checkpoint")
elif 'state_dict' in checkpoint:
state_dict = checkpoint['state_dict']
print("Using 'state_dict' key from checkpoint")
else:
# Direct state dict
state_dict = checkpoint
print("Using checkpoint directly as state_dict")
# Handle prefix in state dict keys for ResNet part
resnet_state_dict = {}
prefixes_to_try = ['', 'module.', 'model.', 'attacker.model.']
resnet_keys = set(resnet.state_dict().keys())
# First check if we can find keys directly in the attacker.model path
print("\n=== Phase 1: Checking for specific model structures ===")
# Check for 'module.model' structure (seen in actual checkpoint)
module_model_keys = [key for key in state_dict.keys() if key.startswith('module.model.')]
if module_model_keys:
print(f"Found 'module.model' structure with {len(module_model_keys)} parameters")
# Extract all parameters from module.model
for source_key, value in state_dict.items():
if source_key.startswith('module.model.'):
target_key = source_key[len('module.model.'):]
resnet_state_dict[target_key] = value
print(f"Extracted {len(resnet_state_dict)} parameters from module.model")
# Check for 'attacker.model' structure
attacker_model_keys = [key for key in state_dict.keys() if key.startswith('attacker.model.')]
if attacker_model_keys:
print(f"Found 'attacker.model' structure with {len(attacker_model_keys)} parameters")
# Extract all parameters from attacker.model
for source_key, value in state_dict.items():
if source_key.startswith('attacker.model.'):
target_key = source_key[len('attacker.model.'):]
resnet_state_dict[target_key] = value
print(f"Extracted {len(resnet_state_dict)} parameters from attacker.model")
# Check if 'model' (not attacker.model) exists as a fallback
model_keys = [key for key in state_dict.keys() if key.startswith('model.') and not key.startswith('attacker.model.')]
if model_keys and len(resnet_state_dict) < len(resnet_keys):
print(f"Found additional 'model.' structure with {len(model_keys)} parameters")
# Try to complete missing parameters
for source_key, value in state_dict.items():
if source_key.startswith('model.'):
target_key = source_key[len('model.'):]
if target_key in resnet_keys and target_key not in resnet_state_dict:
resnet_state_dict[target_key] = value
else:
# Check for other known structures
structure_found = False
# Check for 'model.' prefix
model_keys = [key for key in state_dict.keys() if key.startswith('model.')]
if model_keys:
print(f"Found 'model.' structure with {len(model_keys)} parameters")
for source_key, value in state_dict.items():
if source_key.startswith('model.'):
target_key = source_key[len('model.'):]
resnet_state_dict[target_key] = value
structure_found = True
# Check for ResNet parameters at the top level
top_level_resnet_keys = 0
for key in resnet_keys:
if key in state_dict:
top_level_resnet_keys += 1
if top_level_resnet_keys > 0:
print(f"Found {top_level_resnet_keys} ResNet parameters at top level")
for target_key in resnet_keys:
if target_key in state_dict:
resnet_state_dict[target_key] = state_dict[target_key]
structure_found = True
# If no structure was recognized, try the prefix mapping approach
if not structure_found:
print("No standard model structure found, trying prefix mappings...")
for target_key in resnet_keys:
for prefix in prefixes_to_try:
source_key = prefix + target_key
if source_key in state_dict:
resnet_state_dict[target_key] = state_dict[source_key]
break
# If we still can't find enough keys, try a final approach of removing prefixes
if len(resnet_state_dict) < len(resnet_keys):
print(f"Found only {len(resnet_state_dict)}/{len(resnet_keys)} parameters, trying prefix removal...")
# Track matches found through prefix removal
prefix_matches = {prefix: 0 for prefix in ['module.', 'model.', 'attacker.model.', 'attacker.']}
layer_matches = {} # Track matches by layer type
# Count parameter keys by layer type for analysis
for key in resnet_keys:
layer_name = key.split('.')[0] if '.' in key else key
if layer_name not in layer_matches:
layer_matches[layer_name] = {'total': 0, 'matched': 0}
layer_matches[layer_name]['total'] += 1
# Try keys with common prefixes
for source_key, value in state_dict.items():
# Skip if already found
target_key = source_key
matched_prefix = None
# Try removing various prefixes
for prefix in ['module.', 'model.', 'attacker.model.', 'attacker.']:
if source_key.startswith(prefix):
target_key = source_key[len(prefix):]
matched_prefix = prefix
break
# If the target key is in the ResNet keys, add it to the state dict
if target_key in resnet_keys and target_key not in resnet_state_dict:
resnet_state_dict[target_key] = value
# Update match statistics
if matched_prefix:
prefix_matches[matched_prefix] += 1
# Update layer matches
layer_name = target_key.split('.')[0] if '.' in target_key else target_key
if layer_name in layer_matches:
layer_matches[layer_name]['matched'] += 1
# Print detailed prefix removal statistics
print("\n=== Prefix Removal Statistics ===")
total_matches = sum(prefix_matches.values())
print(f"Total parameters matched through prefix removal: {total_matches}/{len(resnet_keys)} ({(total_matches/len(resnet_keys))*100:.1f}%)")
# Show matches by prefix
print("\nMatches by prefix:")
for prefix, count in sorted(prefix_matches.items(), key=lambda x: x[1], reverse=True):
if count > 0:
print(f" {prefix}: {count} parameters")
# Show matches by layer type
print("\nMatches by layer type:")
for layer, stats in sorted(layer_matches.items(), key=lambda x: x[1]['total'], reverse=True):
match_percent = (stats['matched'] / stats['total']) * 100 if stats['total'] > 0 else 0
print(f" {layer}: {stats['matched']}/{stats['total']} ({match_percent:.1f}%)")
# Check for specific important layers (conv1, layer1, etc.)
critical_layers = ['conv1', 'bn1', 'layer1', 'layer2', 'layer3', 'layer4', 'fc']
print("\nStatus of critical layers:")
for layer in critical_layers:
if layer in layer_matches:
match_percent = (layer_matches[layer]['matched'] / layer_matches[layer]['total']) * 100
status = "✅ COMPLETE" if layer_matches[layer]['matched'] == layer_matches[layer]['total'] else "⚠️ INCOMPLETE"
print(f" {layer}: {layer_matches[layer]['matched']}/{layer_matches[layer]['total']} ({match_percent:.1f}%) - {status}")
else:
print(f" {layer}: Not found in model")
# Load the ResNet state dict
if resnet_state_dict:
try:
# Use strict=False to allow missing keys
result = resnet.load_state_dict(resnet_state_dict, strict=False)
missing_keys, unexpected_keys = result
# Generate detailed information with better formatting
loading_report = []
loading_report.append(f"\n===== MODEL LOADING REPORT: {model_type} =====")
loading_report.append(f"Total parameters in checkpoint: {len(resnet_state_dict):,}")
loading_report.append(f"Total parameters in model: {len(resnet.state_dict()):,}")
loading_report.append(f"Missing keys: {len(missing_keys):,} parameters")
loading_report.append(f"Unexpected keys: {len(unexpected_keys):,} parameters")
# Calculate percentage of parameters loaded
loaded_keys = set(resnet_state_dict.keys()) - set(unexpected_keys)
loaded_percent = (len(loaded_keys) / len(resnet.state_dict())) * 100
# Determine loading success status
if loaded_percent >= 99.5:
status = "✅ COMPLETE - All important parameters loaded"
elif loaded_percent >= 90:
status = "🟡 PARTIAL - Most parameters loaded, should still function"
elif loaded_percent >= 50:
status = "⚠️ INCOMPLETE - Many parameters missing, may not function properly"
else:
status = "❌ FAILED - Critical parameters missing, will not function properly"
loading_report.append(f"Successfully loaded: {len(loaded_keys):,} parameters ({loaded_percent:.1f}%)")
loading_report.append(f"Loading status: {status}")
# If loading is severely incomplete, fall back to PyTorch's pretrained model
if loaded_percent < 50:
loading_report.append("\n⚠️ WARNING: Loading from checkpoint is too incomplete.")
loading_report.append("⚠️ Falling back to PyTorch's pretrained model to avoid broken inference.")
# Create a new ResNet model with pretrained weights
resnet = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
model = nn.Sequential(self.normalizer, resnet)
loading_report.append("✅ Successfully loaded PyTorch's pretrained ResNet50 model")
# Show missing keys by layer type
if missing_keys:
loading_report.append("\nMissing keys by layer type:")
layer_types = {}
for key in missing_keys:
# Extract layer type (e.g., 'conv', 'bn', 'layer1', etc.)
parts = key.split('.')
if len(parts) > 0:
layer_type = parts[0]
if layer_type not in layer_types:
layer_types[layer_type] = 0
layer_types[layer_type] += 1
# Add counts by layer type
for layer_type, count in sorted(layer_types.items(), key=lambda x: x[1], reverse=True):
loading_report.append(f" {layer_type}: {count:,} parameters")
loading_report.append("\nFirst 10 missing keys:")
for i, key in enumerate(sorted(missing_keys)[:10]):
loading_report.append(f" {i+1}. {key}")
# Show unexpected keys if any
if unexpected_keys:
loading_report.append("\nFirst 10 unexpected keys:")
for i, key in enumerate(sorted(unexpected_keys)[:10]):
loading_report.append(f" {i+1}. {key}")
loading_report.append("========================================")
# Convert report to string and print it
report_text = "\n".join(loading_report)
print(report_text)
# Also save to a file for reference
os.makedirs("logs", exist_ok=True)
with open(f"logs/model_loading_{model_type}.log", "w") as f:
f.write(report_text)
# Look for normalizer parameters as well
if any(key.startswith('attacker.normalize.') for key in state_dict.keys()):
norm_state_dict = {}
for key, value in state_dict.items():
if key.startswith('attacker.normalize.'):
norm_key = key[len('attacker.normalize.'):]
norm_state_dict[norm_key] = value
if norm_state_dict:
try:
self.normalizer.load_state_dict(norm_state_dict, strict=False)
print("Successfully loaded normalizer parameters")
except Exception as e:
print(f"Warning: Could not load normalizer parameters: {e}")
except Exception as e:
print(f"Warning: Error loading ResNet parameters: {e}")
# Fall back to loading without normalizer
model = resnet # Use just the ResNet model without normalizer
except Exception as e:
print(f"Error loading model checkpoint: {e}")
# Fallback to PyTorch's pretrained model
print("Falling back to PyTorch's pretrained model")
resnet = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
model = nn.Sequential(self.normalizer, resnet)
else:
# Fallback to PyTorch's pretrained model
print("No checkpoint available, using PyTorch's pretrained model")
resnet = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
model = nn.Sequential(self.normalizer, resnet)
model = model.to(device)
model.eval() # Set to evaluation mode
# Verify model integrity
self.verify_model_integrity(model, model_type)
# Store the model for future use
self.models[model_type] = model
end_time = time.time()
load_time = end_time - start_time
print(f"Model {model_type} loaded in {load_time:.2f} seconds")
return model
def inference(self, image, model_type, config):
"""Run generative inference on the image."""
# Time the entire inference process
inference_start = time.time()
# Load model if not already loaded
model = self.load_model(model_type)
# Check if image is a file path
if isinstance(image, str):
if os.path.exists(image):
image = Image.open(image).convert('RGB')
else:
raise ValueError(f"Image path does not exist: {image}")
        elif isinstance(image, torch.Tensor):
            raise ValueError(f"Expected a PIL image or file path, got {type(image)} (already a transformed tensor?)")
# Prepare image tensor - match original code's conditional transform
load_start = time.time()
        use_norm = bool(config['inference_normalization'])
custom_transform = get_transform(
input_size=224,
normalize=use_norm,
norm_mean=IMAGENET_MEAN,
norm_std=IMAGENET_STD
)
# Special handling for GradModulation as in original
if config['loss_infer'] == 'GradModulation' and 'misc_info' in config and 'grad_modulation' in config['misc_info']:
grad_modulation = config['misc_info']['grad_modulation']
image_tensor = custom_transform(image).unsqueeze(0).to(device)
image_tensor = image_tensor * (1-grad_modulation) + grad_modulation * torch.randn_like(image_tensor).to(device)
else:
image_tensor = custom_transform(image).unsqueeze(0).to(device)
image_tensor.requires_grad = True
print(f"Image loaded and processed in {time.time() - load_start:.2f} seconds")
# Check model structure
is_sequential = isinstance(model, nn.Sequential)
# Get original predictions
with torch.no_grad():
# If the model is sequential with a normalizer, skip the normalization step
if is_sequential and isinstance(model[0], NormalizeByChannelMeanStd):
print("Model is sequential with normalization")
# Get the core model part (typically at index 1 in Sequential)
core_model = model[1]
                if config['inference_normalization']:
                    output_original = model(image_tensor)  # Sequential model applies its own normalizer
                else:
                    output_original = core_model(image_tensor)  # Bypass the normalizer, feed raw [0, 1] input
else:
print("Model is not sequential with normalization")
# Use manual normalization for non-sequential models
if config['inference_normalization']:
normalized_tensor = normalize_transform(image_tensor)
output_original = model(normalized_tensor)
else:
output_original = model(image_tensor)
core_model = model
probs_orig = F.softmax(output_original, dim=1)
conf_orig, classes_orig = torch.max(probs_orig, 1)
# Get least confident classes
_, least_confident_classes = torch.topk(probs_orig, k=100, largest=False)
# Initialize inference step
infer_step = InferStep(image_tensor, config['eps'], config['step_size'])
# Storage for inference steps
# Create a new tensor that requires gradients
x = image_tensor.clone().detach().requires_grad_(True)
all_steps = [image_tensor[0].detach().cpu()]
# For Prior-Guided Drift Diffusion, extract selected layer and initialize with noisy features
noisy_features = None
layer_model = None
if config['loss_infer'] == 'Prior-Guided Drift Diffusion':
print(f"Setting up Prior-Guided Drift Diffusion with layer {config['top_layer']} and noise {config['initial_inference_noise_ratio']}...")
# Extract model up to the specified layer
try:
# Start by finding the actual model to use
base_model = model
# Handle DataParallel wrapper if present
if hasattr(base_model, 'module'):
base_model = base_model.module
# Log the initial model structure
print(f"DEBUG - Initial model structure: {type(base_model)}")
# If we have a Sequential model (which is likely our normalizer + model structure)
if isinstance(base_model, nn.Sequential):
print(f"DEBUG - Sequential model with {len(list(base_model.children()))} children")
# If this is our NormalizeByChannelMeanStd + ResNet pattern
if len(list(base_model.children())) >= 2:
# The actual ResNet model is the second component (index 1)
actual_model = list(base_model.children())[1]
print(f"DEBUG - Using ResNet component: {type(actual_model)}")
print(f"DEBUG - Available layers: {[name for name, _ in actual_model.named_children()]}")
# Extract from the actual ResNet
layer_model = extract_middle_layers(actual_model, config['top_layer'])
else:
# Just a single component Sequential
layer_model = extract_middle_layers(base_model, config['top_layer'])
else:
# Not Sequential, might be direct model
print(f"DEBUG - Available layers: {[name for name, _ in base_model.named_children()]}")
layer_model = extract_middle_layers(base_model, config['top_layer'])
print(f"Successfully extracted model up to layer: {config['top_layer']}")
except ValueError as e:
print(f"Layer extraction failed: {e}. Using full model.")
layer_model = model
# Add noise to the image - exactly match original code
added_noise = config['initial_inference_noise_ratio'] * torch.randn_like(image_tensor).to(device)
noisy_image_tensor = image_tensor + added_noise
# Compute noisy features - simplified to match original code
noisy_features = layer_model(noisy_image_tensor)
print(f"Noisy features computed for Prior-Guided Drift Diffusion target with shape: {noisy_features.shape if hasattr(noisy_features, 'shape') else 'unknown'}")
# Main inference loop
print(f"Starting inference loop with {config['n_itr']} iterations for {config['loss_infer']}...")
loop_start = time.time()
for i in range(config['n_itr']):
# Reset gradients
x.grad = None
# Forward pass - use layer_model for Prior-Guided Drift Diffusion, full model otherwise
if config['loss_infer'] == 'Prior-Guided Drift Diffusion' and layer_model is not None:
# Use the extracted layer model for Prior-Guided Drift Diffusion
# In original code, normalization is handled at transform time, not during forward pass
output = layer_model(x)
else:
# Standard forward pass with full model
# Simplified to match original code's approach
output = model(x)
# Calculate loss and gradients based on inference type
try:
if config['loss_infer'] == 'Prior-Guided Drift Diffusion':
# Use MSE loss to match the noisy features
                    assert config['loss_function'] == 'MSE', "Prior-Guided Drift Diffusion loss function must be MSE"
if noisy_features is not None:
loss = F.mse_loss(output, noisy_features)
grad = torch.autograd.grad(loss, x)[0] # Removed retain_graph=True to match original
else:
raise ValueError("Noisy features not computed for Prior-Guided Drift Diffusion")
else: # Default 'IncreaseConfidence' approach
# Get the least confident classes
num_classes = min(10, least_confident_classes.size(1))
target_classes = least_confident_classes[0, :num_classes]
# Create targets for least confident classes
targets = torch.tensor([idx.item() for idx in target_classes], device=device)
# Use a combined loss to increase confidence
loss = 0
for target in targets:
# Create one-hot target
one_hot = torch.zeros_like(output)
one_hot[0, target] = 1
# Use loss to maximize confidence
loss = loss + F.mse_loss(F.softmax(output, dim=1), one_hot)
grad = torch.autograd.grad(loss, x, retain_graph=True)[0]
if grad is None:
print("Warning: Direct gradient calculation failed")
# Fall back to random perturbation
random_noise = (torch.rand_like(x) - 0.5) * 2 * config['step_size']
x = infer_step.project(x + random_noise)
else:
# Update image with gradient - do this exactly as in original code
adjusted_grad = infer_step.step(x, grad)
# Add diffusion noise if specified
diffusion_noise = config['diffusion_noise_ratio'] * torch.randn_like(x).to(device)
# Apply gradient and noise in one operation before projecting, exactly as in original
x = infer_step.project(x.clone() + adjusted_grad + diffusion_noise)
except Exception as e:
print(f"Error in gradient calculation: {e}")
# Fall back to random perturbation - match original code
random_noise = (torch.rand_like(x) - 0.5) * 2 * config['step_size']
x = infer_step.project(x.clone() + random_noise)
# Store step if in iterations_to_show
if i+1 in config['iterations_to_show'] or i+1 == config['n_itr']:
all_steps.append(x[0].detach().cpu())
# Print some info about the inference
with torch.no_grad():
if is_sequential and isinstance(model[0], NormalizeByChannelMeanStd):
if config['inference_normalization']:
final_output = model(x)
else:
final_output = core_model(x)
else:
if config['inference_normalization']:
normalized_x = normalize_transform(x)
final_output = model(normalized_x)
else:
final_output = model(x)
final_probs = F.softmax(final_output, dim=1)
final_conf, final_classes = torch.max(final_probs, 1)
# Calculate timing information
loop_time = time.time() - loop_start
total_time = time.time() - inference_start
avg_iter_time = loop_time / config['n_itr'] if config['n_itr'] > 0 else 0
print(f"Original top class: {classes_orig.item()} ({conf_orig.item():.4f})")
print(f"Final top class: {final_classes.item()} ({final_conf.item():.4f})")
print(f"Inference loop completed in {loop_time:.2f} seconds ({avg_iter_time:.4f} sec/iteration)")
print(f"Total inference time: {total_time:.2f} seconds")
# Return results in format compatible with both old and new code
return {
'final_image': x[0].detach().cpu(),
'steps': all_steps,
'original_class': classes_orig.item(),
'original_confidence': conf_orig.item(),
'final_class': final_classes.item(),
'final_confidence': final_conf.item()
}
# Utility function to show inference steps
def show_inference_steps(steps, figsize=(15, 10)):
    import matplotlib.pyplot as plt
    n_steps = len(steps)
    fig, axes = plt.subplots(1, n_steps, figsize=figsize)
    if n_steps == 1:
        axes = [axes]  # plt.subplots returns a bare Axes object when n_steps == 1
    for i, step_img in enumerate(steps):
        img = step_img.permute(1, 2, 0).numpy()
        axes[i].imshow(img)
        axes[i].set_title(f"Step {i}")
        axes[i].axis('off')
    plt.tight_layout()
    return fig
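# A minimal end-to-end sketch, guarded so importing this module stays side-effect
# free. The image path is hypothetical; downloading the checkpoint and fetching
# the ImageNet labels both require network access.
if __name__ == "__main__":
    engine = GenerativeInferenceModel()
    demo_config = get_inference_configs('IncreaseConfidence', eps=0.5, n_itr=50, step_size=1.0)
    result = engine.inference("example.jpg", 'resnet50_robust', demo_config)
    print(f"Class {result['original_class']} -> {result['final_class']} "
          f"(confidence {result['original_confidence']:.3f} -> {result['final_confidence']:.3f})")
    fig = show_inference_steps(result['steps'])
    fig.savefig("inference_steps.png")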