# clipL336_TTR/shared.py
import os
import subprocess
import re
import random
import matplotlib.pyplot as plt
import json

def get_gpu_memory_usage():
    """Returns a list of GPU memory usage in MB."""
    try:
        # Run nvidia-smi command and capture the output
        result = subprocess.run(['nvidia-smi', '--query-gpu=memory.used', '--format=csv,nounits,noheader'],
                                stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
        # Check if the command was successful
        if result.returncode != 0:
            raise RuntimeError(f"nvidia-smi command failed with error: {result.stderr}")
        # Parse the output to get memory usage
        memory_usages = [int(x) for x in result.stdout.strip().split('\n')]
        return memory_usages
    except Exception as e:
        print(f"Error querying GPU memory usage: {e}")
        return []

def set_cuda_visible_device():
    """Sets the CUDA_VISIBLE_DEVICES environment variable to the GPU with the smallest memory usage."""
    memory_usages = get_gpu_memory_usage()
    if not memory_usages:
        print("No GPU memory usage data available.")
        return
    # Find the index of the GPU with the smallest memory usage
    min_memory_index = memory_usages.index(min(memory_usages))
    # Set the CUDA_VISIBLE_DEVICES environment variable
    os.environ["CUDA_VISIBLE_DEVICES"] = str(min_memory_index)
    print(f"Set CUDA_VISIBLE_DEVICES to GPU {min_memory_index} with {memory_usages[min_memory_index]} MB used.")
    return str(min_memory_index)
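
# Example usage (a minimal sketch, not part of the original file): pick the
# least-loaded GPU before any CUDA context is created. Assumes nvidia-smi is
# on PATH; call this before the first CUDA allocation so the setting takes effect.
#
#     device_idx = set_cuda_visible_device()
#     print(get_gpu_memory_usage())
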
os.environ["ASN_ROOT_DIR"] = "/home/nickj/asn/second_order_lens"
os.chdir(os.environ["ASN_ROOT_DIR"])
import numpy as np
import torch
from PIL import Image
import os.path
import argparse
from pathlib import Path
from tqdm import tqdm
from utils.factory import create_model_and_transforms, get_tokenizer
from PIL import Image, ImageDraw

def get_model(model_name = "ViT-B/16", pretrained = "openai", device = "cuda:0"):
    torch.multiprocessing.set_sharing_strategy("file_system")
    model, _, preprocess = create_model_and_transforms(
        model_name, pretrained=pretrained, force_quick_gelu=True,
    )
    model.to(device)
    model.eval()
    context_length = model.context_length
    vocab_size = model.vocab_size
    return {
        "model": model,
        "model_name": model_name,
        "pretrained": pretrained,
        "preprocess": preprocess,
        "context_length": context_length,
        "vocab_size": vocab_size
    }

img_path = "/datasets/ilsvrc_2024-01-04_1913/val/n04398044/ILSVRC2012_val_00042447.JPEG"
# img_path = "./sample.jpeg"

def load_images(preprocess, image_folder = "/datasets/ilsvrc/current/val", count = 100, images_only = True):
    file_list = []
    for root, dirs, files in os.walk(image_folder):
        for file in files:
            file_list.append(os.path.join(root, file))
    if count > len(file_list):
        sampled_files = file_list
    else:
        sampled_files = random.sample(file_list, count)
    image_files = []
    for filename in sampled_files:
        image_files.append(preprocess(Image.open(filename)))
    if images_only:
        return image_files
    else:
        return image_files, sampled_files
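
# Example usage (a minimal sketch, not part of the original file). The ImageNet
# validation path above is this repo's default and is an assumption about the
# local dataset layout.
#
#     bundle = get_model("ViT-B/16", pretrained="openai", device="cuda:0")
#     images, paths = load_images(bundle["preprocess"], count=16, images_only=False)
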
def calc_neuron_potentials(model, attn_layers = (1, 2), include_layernorm = True):
    # Calculates the attention-shifting potential score of every neuron with respect to the attention heads
    # in the layers given by attn_layers (offsets relative to the neuron's MLP layer)
    embed_dim = model.visual.transformer.resblocks[0].attn.out_proj.in_features
    num_heads = model.visual.transformer.resblocks[0].attn.num_heads
    head_dim = embed_dim // num_heads
    layers = len(model.visual.transformer.resblocks)
    results = dict()
    for neuron_layer in tqdm(range(layers), desc = "Calculating attention shifting potentials"):
        neuron_projection = model.visual.transformer.resblocks[neuron_layer].state_dict()["mlp.c_proj.weight"]
        for l_attn in range(min(layers, neuron_layer + attn_layers[0]), min(layers, neuron_layer + attn_layers[1])):
            ln_vector = model.visual.transformer.resblocks[l_attn].ln_1.state_dict()["weight"]
            attn_matrix = model.visual.transformer.resblocks[l_attn].state_dict()["attn.in_proj_weight"]
            W_Q, W_K, W_V = (attn_matrix[:embed_dim].reshape(num_heads, head_dim, -1),
                             attn_matrix[embed_dim:2*embed_dim].reshape(num_heads, head_dim, -1),
                             attn_matrix[2*embed_dim:].reshape(num_heads, head_dim, -1))
            for head_idx in range(num_heads):
                W_Q_h, W_K_h = W_Q[head_idx], W_K[head_idx]
                effects = []
                for i in range(neuron_projection.shape[1]):
                    if include_layernorm:
                        neuron_attn_effect = torch.norm(W_Q_h.T @ W_K_h @ (neuron_projection[:, i] * ln_vector))
                    else:
                        neuron_attn_effect = torch.norm(W_Q_h.T @ W_K_h @ neuron_projection[:, i])
                    effects.append(neuron_attn_effect)
                results[(neuron_layer, l_attn, head_idx)] = torch.tensor(effects)
    return results

def calc_top_asns(shift_potentials, top_k = 10, per = "layer", layers_away = 1):
    num_layers = max([key[1] for key in shift_potentials.keys()]) + 1  # the last layer has no ASNs by definition
    num_heads = max([key[2] for key in shift_potentials.keys()]) + 1  # head indices are zero-based
    top_asns = []
    for layer in range(num_layers - layers_away):
        if per == "layer":
            potentials = []
            for head_idx in range(num_heads):
                potentials.append(shift_potentials[(layer, layer + layers_away, head_idx)])
            potentials = torch.max(torch.stack(potentials, dim = 0), dim = 0).values
            _, sorted_indices = torch.sort(potentials, descending = True)
            top_asns.append(sorted_indices[:top_k].tolist())
        elif per == "head":
            top_layer_asns = []
            for head_idx in range(num_heads):
                _, sorted_indices = torch.sort(shift_potentials[(layer, layer + layers_away, head_idx)], descending = True)
                top_layer_asns.append(sorted_indices[:top_k].tolist())
            top_asns.append(top_layer_asns)
        else:
            raise ValueError(f"Invalid per value: {per}")
    return top_asns
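
# Example usage (a minimal sketch, not part of the original file): score every
# neuron's attention-shifting potential and keep the top candidates per layer,
# using the `bundle` dict returned by get_model above.
#
#     potentials = calc_neuron_potentials(bundle["model"], attn_layers=(1, 2))
#     top_asns = calc_top_asns(potentials, top_k=10, per="layer")
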

def aggregate_attn_map(attn_map, layer, head):
    # Sums attention over all query patches (CLS excluded), giving the total attention each patch receives
    num_tokens = attn_map.shape[-1]
    assert (num_tokens - 1) ** 0.5 % 1 == 0, "num_tokens - 1 is not a perfect square"
    num_patches = int((num_tokens - 1) ** 0.5)
    aggregate_scores = torch.sum(attn_map[:, layer, head, 1:, 1:], dim = 1).reshape((1, num_patches, num_patches))
    return aggregate_scores

def attn_map_cls_token(attn_map, layer, head):
    # Gets the attention map for the CLS token
    num_tokens = attn_map.shape[-1]
    assert (num_tokens - 1) ** 0.5 % 1 == 0, "num_tokens - 1 is not a perfect square"
    num_patches = int((num_tokens - 1) ** 0.5)
    attn_map_reshaped = attn_map[:, layer, head, 0, 1:].reshape((1, num_patches, num_patches))
    return attn_map_reshaped

def visualize_attn_shift(attn_map1, attn_map2, image, display=True, out=None, min_diff=None, max_diff=None):
    import matplotlib.pyplot as plt
    import numpy as np
    # Subtract attn_map1 from attn_map2
    diff_map = attn_map2 - attn_map1
    # Convert the image to RGBA
    image = image.convert("RGBA")
    overlay = Image.new("RGBA", image.size, (0, 0, 0, 0))
    draw = ImageDraw.Draw(overlay)
    # Calculate the size of each attention block
    block_size_x = image.size[0] / diff_map.shape[0]
    block_size_y = image.size[1] / diff_map.shape[1]
    # Create a colormap
    cmap = plt.get_cmap('coolwarm_r')  # diverging colormap for negative vs. positive shifts
    # Get the min and max values for scaling the colormap
    if max_diff is None:
        max_diff = diff_map.max()
    if min_diff is None:
        min_diff = diff_map.min()
    for i in range(diff_map.shape[0]):
        for j in range(diff_map.shape[1]):
            # Get the color from the colormap
            intensity = diff_map[i, j]
            normalized_intensity = (intensity - min_diff) / (max_diff - min_diff)  # Scale to [0, 1]
            rgba_color = cmap(1 - normalized_intensity)  # Invert the normalized intensity
            color = tuple(int(c * 255) for c in rgba_color[:3]) + (int(rgba_color[3] * 128),)
            # Draw the rectangle on the overlay with transparency
            draw.rectangle(
                [j * block_size_x, i * block_size_y, (j + 1) * block_size_x, (i + 1) * block_size_y],
                fill=color  # Add transparency to the color
            )
    # Composite the overlay with the original image
    combined = Image.alpha_composite(image, overlay)
    if display:
        # Display the result
        combined.show()
        # Show the color scale
        plt.figure(figsize=(6, 1))
        plt.imshow([np.linspace(min_diff, max_diff, 256)], cmap='coolwarm_r', aspect='auto')
        plt.gca().set_visible(False)
        plt.colorbar(orientation="horizontal")
        plt.show()
    if out is not None:
        combined.save(out)
    return combined
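
# Example usage (a minimal sketch, not part of the original file). `attn_before`
# and `attn_after` are assumed to be (num_patches, num_patches) maps, e.g. taken
# from attn_map_cls_token(...)[0] on a baseline and an ablated forward pass; the
# 224x224 resize matches the ViT-B/16 input resolution.
#
#     img = Image.open(img_path).convert("RGB").resize((224, 224))
#     visualize_attn_shift(attn_before, attn_after, img, display=False, out="shift.png")
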

def visualize_attn_shift_binary(attn_map1, attn_map2, image, display=True, out=None):
    # Creates a visualization of the attention shift where green = positive and red = negative values.
    # This is useful when outliers in the difference map would otherwise squash the values around 0 into a single color.
    # Subtract attn_map1 from attn_map2
    diff_map = attn_map2 - attn_map1
    # Normalize the difference map to range [0, 1] for visualization
    diff_map_normalized = (diff_map - diff_map.min()) / (diff_map.max() - diff_map.min())
    # Convert the image to RGBA
    image = image.convert("RGBA")
    overlay = Image.new("RGBA", image.size, (0, 0, 0, 0))
    draw = ImageDraw.Draw(overlay)
    # Calculate the size of each attention block
    block_size_x = image.size[0] / diff_map.shape[0]
    block_size_y = image.size[1] / diff_map.shape[1]
    for i in range(diff_map.shape[0]):
        for j in range(diff_map.shape[1]):
            # Calculate the color intensity based on the difference
            intensity = diff_map_normalized[i, j]
            alpha = int(255 * 0.5)  # Tone down the alpha to 50%
            if diff_map[i, j] > 0:
                color = (0, int(255 * intensity), 0, alpha)  # Green for positive
            else:
                color = (int(255 * (1 - intensity)), 0, 0, alpha)  # Red for negative
            # Draw the rectangle on the overlay
            draw.rectangle(
                [j * block_size_x, i * block_size_y, (j + 1) * block_size_x, (i + 1) * block_size_y],
                fill=color
            )
    # Composite the overlay with the original image
    combined = Image.alpha_composite(image, overlay)
    if display:
        # Display the result
        combined.show()
    if out is not None:
        combined.save(out)
    return combined

def is_outlier(mean, std, value):
    return value < mean - 2 * std or value > mean + 2 * std

def get_neuron_activations(images, prs_group, model, device = "cuda:0"):
    # Returns neuron activations in shape (num_images, num_layers, num_patches, num_neurons)
    random_neuron_acts = []
    for image in tqdm(images, desc="Processing images"):
        prs_group.reinit()
        image_input = image.unsqueeze(0).to(device)
        representation = model.encode_image(
            image_input, attn_method="head", normalize=False
        )
        prs_group.finalize()
        gelu_outs = prs_group.post_gelu_outputs()
        random_neuron_acts.append(gelu_outs)
    random_neuron_acts = torch.stack(random_neuron_acts, dim = 0)
    return random_neuron_acts

def normalize_array(arr):
    min_val = np.min(arr)
    max_val = np.max(arr)
    # Avoid division by zero if all values are the same
    if max_val - min_val == 0:
        return np.zeros_like(arr)
    normalized_arr = (arr - min_val) / (max_val - min_val)
    return normalized_arr

def np_l2(arr1, arr2):
    return np.linalg.norm(arr1 - arr2)

def best_class(classifier, representation):
    # Returns the index of the classifier column most cosine-similar to the representation, and that similarity
    cs = torch.cosine_similarity(classifier, representation.permute(1, 0), dim = 0)
    return torch.argmax(cs).item(), cs[torch.argmax(cs).item()].item()

def load_group_attn_shifts(timestamp):
    # Load results from the Supp1B experiment for the given timestamped run
    results_dir = "./results/supp1B"
    # dirs = [os.path.join(results_dir, d) for d in os.listdir(results_dir)
    #         if os.path.isdir(os.path.join(results_dir, d))]
    # latest_dir = max(dirs, key=os.path.getmtime)
    latest_dir = os.path.join(results_dir, timestamp)
    print(f"Using results directory: {latest_dir}")
    # Load metadata
    with open(os.path.join(latest_dir, "metadata.json"), "r") as f:
        metadata = json.load(f)
    # Load memory-mapped files
    attn_maps = np.memmap(os.path.join(latest_dir, "attention_maps.mmap"),
                          dtype=np.float32,
                          mode='r',
                          shape=tuple(metadata["attention_maps_shape"]))
    resblocks = np.memmap(os.path.join(latest_dir, "resblocks.mmap"),
                          dtype=np.float32,
                          mode='r',
                          shape=tuple(metadata["resblocks_shape"]))
    # Get file list from metadata
    file_list = metadata.get("file_list", [])
    # Get top_k values
    top_k_values = metadata.get("top_k_values", [0])
    return {
        "attn_maps": attn_maps,
        "resblocks": resblocks,
        "metadata": metadata,
        "file_list": file_list,
        "top_k_values": top_k_values,
        "num_layers": metadata.get("num_layers", 0),
        "num_images": metadata.get("num_images", 0),
        "num_heads": metadata.get("num_heads", 0)
    }

def load_individual_attn_shifts(timestamp, supp = "supp1D"):
    # Load results from an individual-ablation experiment (e.g. Supp1D) for the given timestamped run
    results_dir = f"./results/{supp}"
    # dirs = [os.path.join(results_dir, d) for d in os.listdir(results_dir)
    #         if os.path.isdir(os.path.join(results_dir, d))]
    # latest_dir = max(dirs, key=os.path.getmtime)
    latest_dir = os.path.join(results_dir, timestamp)
    print(f"Using results directory: {latest_dir}")
    # Load metadata
    with open(os.path.join(latest_dir, "metadata.json"), "r") as f:
        metadata = json.load(f)
    # Load memory-mapped files
    attn_maps = np.memmap(os.path.join(latest_dir, "attention_maps.mmap"),
                          dtype=np.float32,
                          mode='r',
                          shape=tuple(metadata["attention_maps_shape"]))
    baseline_attn_maps = np.memmap(os.path.join(latest_dir, "baseline_attention_maps.mmap"),
                                   dtype=np.float32,
                                   mode='r',
                                   shape=tuple(metadata["baseline_attention_maps_shape"]))
    neuron_activations = np.memmap(os.path.join(latest_dir, "neuron_activations.mmap"),
                                   dtype=np.float32,
                                   mode='r',
                                   shape=tuple(metadata["neuron_activations_shape"]))
    baseline_neuron_activations = np.memmap(os.path.join(latest_dir, "baseline_neuron_activations.mmap"),
                                            dtype=np.float32,
                                            mode='r',
                                            shape=tuple(metadata["baseline_neuron_activations_shape"]))
    ablated_neurons = np.memmap(os.path.join(latest_dir, "ablated_neurons.mmap"),
                                dtype=np.float32,
                                mode='r',
                                shape=tuple(metadata["ablated_neurons_shape"]))
    # Get file list from metadata
    file_list = metadata.get("file_list", [])
    # Get k value
    k = metadata.get("k", 25)
    return {
        "attn_maps": attn_maps,
        "baseline_attn_maps": baseline_attn_maps,
        "neuron_activations": neuron_activations,
        "baseline_neuron_activations": baseline_neuron_activations,
        "ablated_neurons": ablated_neurons,
        "metadata": metadata,
        "file_list": file_list,
        "k": k,
        "num_layers": metadata.get("num_layers", 12),
        "num_images": metadata.get("num_images", 100),
        "model_name": metadata.get("model_name", "ViT-B-16"),
        "pretrained": metadata.get("pretrained", "openai")
    }
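
# Example usage (a minimal sketch, not part of the original file). The timestamp
# is a placeholder for an existing run directory under ./results/supp1D.
#
#     run = load_individual_attn_shifts("<timestamp>", supp="supp1D")
#     print(run["attn_maps"].shape, run["baseline_attn_maps"].shape, run["k"])
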

def find_register_neurons_cuda(model, preprocess, prs_group, register_norm_threshold = 30, highest_layer = -1, device = "cuda:0", processed_image_cnt = 500):
    num_layers = len(model.visual.transformer.resblocks)
    highest_layer = num_layers - 1 if highest_layer == -1 else highest_layer
    num_neurons = model.visual.transformer.resblocks[0].mlp.state_dict()["c_proj.weight"].shape[1]
    random_images = load_images(preprocess, count=processed_image_cnt)
    neuron_scores = torch.zeros((len(random_images), num_layers, num_neurons), device=device)
    alignment_scores = torch.zeros((len(random_images), num_layers, num_neurons), device=device)
    image_count = 0
    for i in tqdm(range(len(random_images)), desc="Processing random images"):
        image = random_images[i].unsqueeze(0).to(device)
        prs_group.reinit()
        with torch.inference_mode():
            representation = model.encode_image(
                image, attn_method="head", normalize=False
            )
        prs_group.finalize()
        baseline_neuron_acts = prs_group.post_gelu_outputs().to(device)
        baseline_resblock_outputs = prs_group.resblock_outputs().to(device)
        # Calculate norm map using torch
        norm_map = torch.norm(baseline_resblock_outputs[-1], dim=1)
        filtered_norms = norm_map.clone()
        filtered_norms[filtered_norms < register_norm_threshold] = 0
        # Get register locations as a tensor
        register_locations = torch.where(filtered_norms > register_norm_threshold)[0]
        if len(register_locations) == 0:
            continue
        image_count += 1
        # Process all layers vectorized
        for layer in range(num_layers):
            # Get absolute activations for all neurons in this layer
            act_layer = torch.abs(baseline_neuron_acts[layer])  # Shape: [seq_len, num_neurons]
            # Check sparsity condition for all neurons at once
            sparse_neurons = torch.sum(act_layer < 0.5, dim=0) >= 0.5 * act_layer.shape[0]  # Shape: [num_neurons]
            # Skip computation if no neurons meet the condition
            if not torch.any(sparse_neurons):
                continue
            # Get values at register locations for all neurons simultaneously
            # This creates a tensor of shape [num_register_locations, num_neurons]
            register_values = act_layer[register_locations]
            # For neurons that pass the sparsity condition, compute the mean at register locations
            # First, compute the mean for all neurons (this is fast)
            neuron_means = register_values.mean(dim=0)  # Shape: [num_neurons]
            # Then zero out means for neurons that don't pass the sparsity condition
            neuron_means = neuron_means * sparse_neurons.float()
            # Store in the score tensor
            neuron_scores[i, layer] = neuron_means
    # Average the per-image scores
    mean_neuron_scores = neuron_scores[:image_count].mean(dim=0)
    mean_alignment_scores = alignment_scores[:image_count].mean(dim=0)
    # Flatten and find top values
    flattened_scores = mean_neuron_scores.flatten()
    sorted_values, sorted_indices = torch.sort(flattened_scores, descending=True)
    flattened_alignment = mean_alignment_scores.flatten()
    sorted_alignment_values, sorted_alignment_indices = torch.sort(flattened_alignment, descending=True)
    # Convert indices to layer/neuron pairs
    top_indices = [(idx.item() // num_neurons, idx.item() % num_neurons) for idx in sorted_indices]
    top_alignment_indices = [(idx.item() // num_neurons, idx.item() % num_neurons) for idx in sorted_alignment_indices]
    register_norms = [
        (layer, neuron, sorted_values[i].item())
        for i, (layer, neuron) in enumerate(top_indices)
        if layer <= highest_layer
    ]
    best_alignment_scores = [
        (layer, neuron, sorted_alignment_values[i].item())
        for i, (layer, neuron) in enumerate(top_alignment_indices)
        if layer <= highest_layer
    ]
    return register_norms, best_alignment_scores
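
# Example usage (a minimal sketch, not part of the original file). `prs_group`
# is assumed to be the hook wrapper used throughout this repo that exposes
# reinit(), finalize(), post_gelu_outputs(), and resblock_outputs(); its
# construction lives outside this file.
#
#     register_norms, _ = find_register_neurons_cuda(
#         bundle["model"], bundle["preprocess"], prs_group, processed_image_cnt=100)
#     print(register_norms[:10])  # top (layer, neuron, score) candidates
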

def find_register_neurons(model, preprocess, prs_group, register_norm_threshold = 30, highest_layer = -1, device = "cuda:0", processed_image_cnt = 500):
    num_layers = len(model.visual.transformer.resblocks)
    highest_layer = num_layers - 1 if highest_layer == -1 else highest_layer
    num_neurons = model.visual.transformer.resblocks[0].mlp.state_dict()["c_proj.weight"].shape[1]
    random_images = load_images(preprocess, count = processed_image_cnt)
    neuron_scores = torch.zeros((len(random_images), num_layers, num_neurons))
    for i in tqdm(range(len(random_images)), desc="Processing random images"):
        image = random_images[i].unsqueeze(0).to(device)
        prs_group.reinit()
        with torch.no_grad():
            representation = model.encode_image(
                image, attn_method="head", normalize=False
            )
        prs_group.finalize()
        # Gather neuron activations and resblock outputs
        baseline_neuron_acts = prs_group.post_gelu_outputs().cpu().numpy()
        baseline_resblock_outputs = prs_group.resblock_outputs().cpu().numpy()
        # Calculate norms of the last resblock outputs. Only consider patches of the activation maps that correspond with registers
        norms = np.linalg.norm(baseline_resblock_outputs[-1], axis=1)
        norms[norms < register_norm_threshold] = 0
        register_locations = np.where(norms > register_norm_threshold)[0]
        # register_neurons = []
        for layer in range(num_layers):
            for neuron in range(num_neurons):
                neuron_map = baseline_neuron_acts[layer, :, neuron]
                mask = np.zeros_like(neuron_map, dtype=bool)
                mask[register_locations] = True
                neuron_map[~mask] = 0
                if np.any(neuron_map < 0):
                    continue
                # dist = np.linalg.norm(normalize_array(norms) - normalize_array(neuron_map))
                # register_neurons.append((layer, neuron, dist.item(), neuron_map[register_locations].mean()))
                neuron_scores[i, layer, neuron] = torch.tensor(neuron_map[register_locations].mean())
    mean_neuron_scores = neuron_scores.mean(dim=0)
    # Flatten the 2D tensor to find global top values
    flattened_scores = mean_neuron_scores.flatten()
    sorted_values, sorted_indices = torch.sort(flattened_scores, descending=True)
    # Convert flat indices back to 2D coordinates (layer, neuron)
    top_indices = [(idx.item() // num_neurons, idx.item() % num_neurons) for idx in sorted_indices]
    return [(layer, neuron, sorted_values[i].item()) for i, (layer, neuron) in enumerate(top_indices) if layer <= highest_layer]

def plot_attn_maps(attn_maps, image_idx):
    num_layers, num_heads, patch_height, patch_width = attn_maps.shape
    print(f"Shape of image_shifts: {attn_maps.shape}")
    # Create a grid of plots for all layers and heads
    fig, axes = plt.subplots(num_layers, num_heads, figsize=(2*num_heads, 2*num_layers))
    fig.suptitle(f'Attention Shift Maps for Image #{image_idx}', fontsize=16)
    # Import the module for make_axes_locatable
    from mpl_toolkits.axes_grid1 import make_axes_locatable
    # Plot each layer-head combination
    for layer in range(num_layers):
        # Determine min and max for this layer for consistent colorbar scaling within the layer
        layer_vmin = attn_maps[layer].min().item()
        layer_vmax = attn_maps[layer].max().item()
        for head in range(num_heads):
            # Get the current axis (handle both 2D and 1D cases)
            if num_layers == 1 and num_heads == 1:
                ax = axes
            elif num_layers == 1:
                ax = axes[head]
            elif num_heads == 1:
                ax = axes[layer]
            else:
                ax = axes[layer, head]
            # Plot the attention shift map with layer-specific normalization
            im = ax.imshow(attn_maps[layer, head], cmap='viridis', vmin=layer_vmin, vmax=layer_vmax)
            # Remove ticks for cleaner appearance
            ax.set_xticks([])
            ax.set_yticks([])
            # Add layer and head labels only on the edges
            if head == 0:
                ax.set_ylabel(f'Layer {layer}')
            if layer == num_layers-1:
                ax.set_xlabel(f'Head {head}')
            # Add a colorbar for each layer (only once per row)
            if head == num_heads-1:
                # Create a colorbar that's properly sized relative to the plot
                divider = make_axes_locatable(ax)
                cax = divider.append_axes("right", size="5%", pad=0.05)
                plt.colorbar(im, cax=cax)
    # Adjust layout to make room for the colorbars
    plt.tight_layout()
    return plt
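
# Example usage (a minimal sketch, not part of the original file): plot a random
# stack of (layers, heads, patches, patches) maps; the shape is illustrative.
#
#     maps = torch.rand(12, 12, 14, 14)
#     plot_attn_maps(maps, image_idx=0).show()
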

def calculate_iou(output, target):
    # output and target are binary masks; the intersection keeps positives present in both
    intersection = output * (output == target)
    area_inter = intersection.sum().item()
    area_pred = output.sum().item()
    area_target = target.sum().item()
    union = area_pred + area_target - area_inter
    iou = area_inter / union
    return area_inter, union, iou

def calculate_pixel_accuracy(output, target):
    # Fraction of the target's positive pixels that the output also marks positive
    correct = output * (output == target)
    correct = correct.sum().item()
    total = target.sum().item()
    return correct, total, correct / total
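
# Example usage (a minimal sketch, not part of the original file): evaluate a
# predicted binary mask against a ground-truth mask.
#
#     pred = torch.tensor([[1, 0], [1, 1]])
#     gt = torch.tensor([[1, 0], [0, 1]])
#     print(calculate_iou(pred, gt))             # (2, 3, 0.666...)
#     print(calculate_pixel_accuracy(pred, gt))  # (2, 2, 1.0)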