import os
import subprocess
import re
import random
import json

import matplotlib.pyplot as plt


def get_gpu_memory_usage():
    """Returns a list of per-GPU memory usage in MB, as reported by nvidia-smi."""
    try:
        # Query only the used-memory column, without units or a header row,
        # so the output is one integer per GPU.
        result = subprocess.run(['nvidia-smi', '--query-gpu=memory.used', '--format=csv,nounits,noheader'],
                                stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)

        if result.returncode != 0:
            raise RuntimeError(f"nvidia-smi command failed with error: {result.stderr}")

        memory_usages = [int(x) for x in result.stdout.strip().split('\n')]
        return memory_usages
    except Exception as e:
        print(f"Error querying GPU memory usage: {e}")
        return []


def set_cuda_visible_device():
    """Sets the CUDA_VISIBLE_DEVICES environment variable to the GPU with the smallest memory usage."""
    memory_usages = get_gpu_memory_usage()

    if not memory_usages:
        print("No GPU memory usage data available.")
        return None

    min_memory_index = memory_usages.index(min(memory_usages))

    os.environ["CUDA_VISIBLE_DEVICES"] = str(min_memory_index)
    print(f"Set CUDA_VISIBLE_DEVICES to GPU {min_memory_index} with {memory_usages[min_memory_index]} MB used.")

    return str(min_memory_index)


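# Example usage (a minimal sketch; assumes a machine with nvidia-smi on PATH and
# at least one visible GPU). Call this before any CUDA context is created, since
# CUDA_VISIBLE_DEVICES is only read when CUDA initializes:
#
#   chosen_gpu = set_cuda_visible_device()
#   if chosen_gpu is None:
#       print("Falling back to default device selection.")

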
os.environ["ASN_ROOT_DIR"] = "/home/nickj/asn/second_order_lens" |
|
os.chdir(os.environ["ASN_ROOT_DIR"]) |
|
|
|
import numpy as np |
|
import torch |
|
from PIL import Image |
|
import os.path |
|
import argparse |
|
from pathlib import Path |
|
|
|
from tqdm import tqdm |
|
from utils.factory import create_model_and_transforms, get_tokenizer |
|
from PIL import Image, ImageDraw |
|
|
|
def get_model(model_name="ViT-B/16", pretrained="openai", device="cuda:0"):
    torch.multiprocessing.set_sharing_strategy("file_system")
    model, _, preprocess = create_model_and_transforms(
        model_name, pretrained=pretrained, force_quick_gelu=True,
    )
    model.to(device)
    model.eval()
    context_length = model.context_length
    vocab_size = model.vocab_size

    return {
        "model": model,
        "model_name": model_name,
        "pretrained": pretrained,
        "preprocess": preprocess,
        "context_length": context_length,
        "vocab_size": vocab_size
    }


# Sample ImageNet validation image used for quick manual checks.
img_path = "/datasets/ilsvrc_2024-01-04_1913/val/n04398044/ILSVRC2012_val_00042447.JPEG"


def load_images(preprocess, image_folder="/datasets/ilsvrc/current/val", count=100, images_only=True):
    file_list = []

    for root, dirs, files in os.walk(image_folder):
        for file in files:
            file_list.append(os.path.join(root, file))

    # If fewer files exist than requested, use all of them.
    if count > len(file_list):
        sampled_files = file_list
    else:
        sampled_files = random.sample(file_list, count)

    image_files = []
    for filename in sampled_files:
        image_files.append(preprocess(Image.open(filename)))

    if images_only:
        return image_files
    else:
        return image_files, sampled_files


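# Example usage (a minimal sketch; assumes the OpenAI ViT-B/16 weights are
# available to create_model_and_transforms and that image_folder points at a
# directory tree of images):
#
#   bundle = get_model(model_name="ViT-B/16", pretrained="openai", device="cuda:0")
#   model, preprocess = bundle["model"], bundle["preprocess"]
#   images, paths = load_images(preprocess, count=16, images_only=False)
#   batch = torch.stack(images).to("cuda:0")  # (16, 3, 224, 224) for ViT-B/16

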
def calc_neuron_potentials(model, attn_layers=(1, 2), include_layernorm=True):
    embed_dim = model.visual.transformer.resblocks[0].attn.out_proj.in_features
    num_heads = model.visual.transformer.resblocks[0].attn.num_heads
    head_dim = embed_dim // num_heads
    layers = len(model.visual.transformer.resblocks)

    results = dict()

    for neuron_layer in tqdm(range(layers), desc="Calculating attention shifting potentials"):
        # Columns of the MLP output projection are the write directions of
        # individual neurons in this block.
        neuron_projection = model.visual.transformer.resblocks[neuron_layer].state_dict()["mlp.c_proj.weight"]
        for l_attn in range(min(layers, neuron_layer + attn_layers[0]), min(layers, neuron_layer + attn_layers[1])):
            ln_vector = model.visual.transformer.resblocks[l_attn].ln_1.state_dict()["weight"]
            # in_proj_weight stacks the query, key, and value projections;
            # split it into per-head (head_dim, embed_dim) matrices.
            attn_matrix = model.visual.transformer.resblocks[l_attn].state_dict()["attn.in_proj_weight"]
            W_Q, W_K, W_V = (attn_matrix[:embed_dim].reshape(num_heads, head_dim, -1),
                             attn_matrix[embed_dim:2*embed_dim].reshape(num_heads, head_dim, -1),
                             attn_matrix[2*embed_dim:].reshape(num_heads, head_dim, -1))

            for head_idx in range(num_heads):
                W_Q_h, W_K_h = W_Q[head_idx], W_K[head_idx]
                effects = []
                for i in range(neuron_projection.shape[1]):
                    # A neuron's shifting potential is the norm of its write
                    # direction pushed through the (optionally LayerNorm-scaled)
                    # key projection and back through the query projection.
                    if include_layernorm:
                        neuron_attn_effect = torch.norm(W_Q_h.T @ W_K_h @ (neuron_projection[:, i] * ln_vector))
                    else:
                        neuron_attn_effect = torch.norm(W_Q_h.T @ W_K_h @ neuron_projection[:, i])
                    effects.append(neuron_attn_effect)

                results[(neuron_layer, l_attn, head_idx)] = torch.stack(effects)
    return results


def calc_top_asns(shift_potentials, top_k=10, per="layer", layers_away=1):
    # Keys are (neuron_layer, attn_layer, head); both indices are 0-based,
    # so counts are max + 1.
    num_layers = max([key[1] for key in shift_potentials.keys()]) + 1
    num_heads = max([key[2] for key in shift_potentials.keys()]) + 1

    top_asns = []
    for layer in range(num_layers - layers_away):
        if per == "layer":
            # Score each neuron by its maximum potential across heads.
            potentials = []
            for head_idx in range(num_heads):
                potentials.append(shift_potentials[(layer, layer + layers_away, head_idx)])
            potentials = torch.max(torch.stack(potentials, dim=0), dim=0).values
            _, sorted_indices = torch.sort(potentials, descending=True)
            top_asns.append(sorted_indices[:top_k].tolist())
        elif per == "head":
            top_layer_asns = []
            for head_idx in range(num_heads):
                _, sorted_indices = torch.sort(shift_potentials[(layer, layer + layers_away, head_idx)], descending=True)
                top_layer_asns.append(sorted_indices[:top_k].tolist())
            top_asns.append(top_layer_asns)
        else:
            raise ValueError(f"Invalid per value: {per}")
    return top_asns


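# Example usage (a minimal sketch; assumes `bundle` from get_model above).
# Potentials are keyed by (neuron_layer, attn_layer, head), so with the default
# attn_layers=(1, 2) each neuron is scored against heads one block downstream:
#
#   potentials = calc_neuron_potentials(bundle["model"])
#   top_asns = calc_top_asns(potentials, top_k=10, per="layer", layers_away=1)
#   print(top_asns[0])  # indices of the ten highest-potential neurons in layer 0

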
def aggregate_attn_map(attn_map, layer, head):
    num_tokens = attn_map.shape[-1]
    assert (num_tokens - 1) ** 0.5 % 1 == 0, "num_tokens - 1 is not a perfect square"

    # Drop the CLS token, then sum the attention each patch receives over all
    # query patches, yielding one score per patch position.
    num_patches = int((num_tokens - 1) ** 0.5)
    aggregate_scores = torch.sum(attn_map[:, layer, head, 1:, 1:], dim=1).reshape((1, num_patches, num_patches))
    return aggregate_scores


def attn_map_cls_token(attn_map, layer, head):
    num_tokens = attn_map.shape[-1]
    assert (num_tokens - 1) ** 0.5 % 1 == 0, "num_tokens - 1 is not a perfect square"

    # Attention from the CLS token (query index 0) to every image patch.
    num_patches = int((num_tokens - 1) ** 0.5)
    attn_map_reshaped = attn_map[:, layer, head, 0, 1:].reshape((1, num_patches, num_patches))
    return attn_map_reshaped


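# Example usage (a minimal sketch with a stand-in attention tensor; real maps
# shaped (batch, layers, heads, tokens, tokens) come from the loaders below):
#
#   attn = torch.randn(1, 12, 12, 197, 197).softmax(dim=-1)
#   cls_map = attn_map_cls_token(attn, layer=6, head=3)  # (1, 14, 14)
#   agg_map = aggregate_attn_map(attn, layer=6, head=3)  # (1, 14, 14)

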
def visualize_attn_shift(attn_map1, attn_map2, image, display=True, out=None, min_diff=None, max_diff=None):
    diff_map = attn_map2 - attn_map1

    image = image.convert("RGBA")
    overlay = Image.new("RGBA", image.size, (0, 0, 0, 0))
    draw = ImageDraw.Draw(overlay)

    # image.size is (width, height) while diff_map is indexed (row, col), so
    # x blocks divide the width by the number of columns and vice versa.
    block_size_x = image.size[0] / diff_map.shape[1]
    block_size_y = image.size[1] / diff_map.shape[0]

    cmap = plt.get_cmap('coolwarm_r')

    if max_diff is None:
        max_diff = diff_map.max()
    if min_diff is None:
        min_diff = diff_map.min()

    for i in range(diff_map.shape[0]):
        for j in range(diff_map.shape[1]):
            intensity = diff_map[i, j]
            normalized_intensity = (intensity - min_diff) / (max_diff - min_diff)
            # Flipping the reversed colormap maps attention increases to red and
            # decreases to blue; the overlay is drawn at half opacity.
            rgba_color = cmap(1 - normalized_intensity)
            color = tuple(int(c * 255) for c in rgba_color[:3]) + (int(rgba_color[3] * 128),)

            draw.rectangle(
                [j * block_size_x, i * block_size_y, (j + 1) * block_size_x, (i + 1) * block_size_y],
                fill=color
            )

    combined = Image.alpha_composite(image, overlay)

    if display:
        combined.show()

        # Render a standalone horizontal colorbar for the difference scale; the
        # unreversed colormap matches the overlay (red = increase).
        plt.figure(figsize=(6, 1))
        plt.imshow([np.linspace(min_diff, max_diff, 256)], cmap='coolwarm', aspect='auto')
        plt.gca().set_visible(False)
        plt.colorbar(orientation="horizontal")
        plt.show()

    if out is not None:
        combined.save(out)

    return combined


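# Example usage (a minimal sketch; `attn_before`/`attn_after` are hypothetical
# attention tensors captured before and after ablating a neuron, and `img` is
# the corresponding PIL image; the visualizers expect 2-D maps, hence the [0]):
#
#   before = attn_map_cls_token(attn_before, layer=6, head=3)[0]
#   after = attn_map_cls_token(attn_after, layer=6, head=3)[0]
#   visualize_attn_shift(before, after, img, display=False, out="shift.png")

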
def visualize_attn_shift_binary(attn_map1, attn_map2, image, display=True, out=None):
    diff_map = attn_map2 - attn_map1

    # Normalize the differences to [0, 1] for color intensity.
    diff_map_normalized = (diff_map - diff_map.min()) / (diff_map.max() - diff_map.min())

    image = image.convert("RGBA")
    overlay = Image.new("RGBA", image.size, (0, 0, 0, 0))
    draw = ImageDraw.Draw(overlay)

    # image.size is (width, height) while diff_map is indexed (row, col).
    block_size_x = image.size[0] / diff_map.shape[1]
    block_size_y = image.size[1] / diff_map.shape[0]

    for i in range(diff_map.shape[0]):
        for j in range(diff_map.shape[1]):
            intensity = diff_map_normalized[i, j]
            alpha = int(255 * 0.5)
            # Green marks patches whose attention increased, red those whose
            # attention decreased.
            if diff_map[i, j] > 0:
                color = (0, int(255 * intensity), 0, alpha)
            else:
                color = (int(255 * (1 - intensity)), 0, 0, alpha)

            draw.rectangle(
                [j * block_size_x, i * block_size_y, (j + 1) * block_size_x, (i + 1) * block_size_y],
                fill=color
            )

    combined = Image.alpha_composite(image, overlay)

    if display:
        combined.show()

    if out is not None:
        combined.save(out)

    return combined


def is_outlier(mean, std, value):
    # Two-sigma rule: flag values more than two standard deviations from the mean.
    return value < mean - 2 * std or value > mean + 2 * std


def get_neuron_activations(images, prs_group, model, device="cuda:0"):
    random_neuron_acts = []
    for image in tqdm(images, desc="Processing images"):
        prs_group.reinit()
        image_input = image.unsqueeze(0).to(device)
        # The representation itself is unused; encode_image is run for its side
        # effect of populating the prs_group hooks.
        with torch.no_grad():
            model.encode_image(
                image_input, attn_method="head", normalize=False
            )
        prs_group.finalize()
        gelu_outs = prs_group.post_gelu_outputs()
        random_neuron_acts.append(gelu_outs)
    random_neuron_acts = torch.stack(random_neuron_acts, dim=0)
    return random_neuron_acts


def normalize_array(arr):
    min_val = np.min(arr)
    max_val = np.max(arr)

    if max_val - min_val == 0:
        return np.zeros_like(arr)
    normalized_arr = (arr - min_val) / (max_val - min_val)
    return normalized_arr


def np_l2(arr1, arr2):
    return np.linalg.norm(arr1 - arr2)


def best_class(classifier, representation):
    # Cosine similarity between the representation and every class embedding.
    cs = torch.cosine_similarity(classifier, representation.permute(1, 0), dim=0)
    best_idx = torch.argmax(cs).item()
    return best_idx, cs[best_idx].item()


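# Example usage (a minimal sketch; assumes `classifier` is an (embed_dim, num_classes)
# zero-shot weight matrix and `representation` is a (1, embed_dim) image embedding):
#
#   class_idx, similarity = best_class(classifier, representation)
#   print(f"Predicted class {class_idx} (cosine similarity {similarity:.3f})")

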
def load_group_attn_shifts(timestamp):
    results_dir = "./results/supp1B"

    latest_dir = os.path.join(results_dir, timestamp)
    print(f"Using results directory: {latest_dir}")

    with open(os.path.join(latest_dir, "metadata.json"), "r") as f:
        metadata = json.load(f)

    # Memory-map the large arrays read-only; shapes are recorded in the metadata.
    attn_maps = np.memmap(os.path.join(latest_dir, "attention_maps.mmap"),
                          dtype=np.float32,
                          mode='r',
                          shape=tuple(metadata["attention_maps_shape"]))

    resblocks = np.memmap(os.path.join(latest_dir, "resblocks.mmap"),
                          dtype=np.float32,
                          mode='r',
                          shape=tuple(metadata["resblocks_shape"]))

    file_list = metadata.get("file_list", [])
    top_k_values = metadata.get("top_k_values", [0])

    return {
        "attn_maps": attn_maps,
        "resblocks": resblocks,
        "metadata": metadata,
        "file_list": file_list,
        "top_k_values": top_k_values,
        "num_layers": metadata.get("num_layers", 0),
        "num_images": metadata.get("num_images", 0),
        "num_heads": metadata.get("num_heads", 0)
    }


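# Example usage (a minimal sketch; the timestamp is a placeholder for a real run
# directory under ./results/supp1B/):
#
#   group = load_group_attn_shifts("<timestamp>")
#   print(group["attn_maps"].shape, group["num_images"], group["num_heads"])

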
def load_individual_attn_shifts(timestamp, supp="supp1D"):
    results_dir = f"./results/{supp}"

    latest_dir = os.path.join(results_dir, timestamp)
    print(f"Using results directory: {latest_dir}")

    with open(os.path.join(latest_dir, "metadata.json"), "r") as f:
        metadata = json.load(f)

    # Memory-map the ablated and baseline runs read-only; shapes come from the metadata.
    attn_maps = np.memmap(os.path.join(latest_dir, "attention_maps.mmap"),
                          dtype=np.float32,
                          mode='r',
                          shape=tuple(metadata["attention_maps_shape"]))

    baseline_attn_maps = np.memmap(os.path.join(latest_dir, "baseline_attention_maps.mmap"),
                                   dtype=np.float32,
                                   mode='r',
                                   shape=tuple(metadata["baseline_attention_maps_shape"]))

    neuron_activations = np.memmap(os.path.join(latest_dir, "neuron_activations.mmap"),
                                   dtype=np.float32,
                                   mode='r',
                                   shape=tuple(metadata["neuron_activations_shape"]))

    baseline_neuron_activations = np.memmap(os.path.join(latest_dir, "baseline_neuron_activations.mmap"),
                                            dtype=np.float32,
                                            mode='r',
                                            shape=tuple(metadata["baseline_neuron_activations_shape"]))

    ablated_neurons = np.memmap(os.path.join(latest_dir, "ablated_neurons.mmap"),
                                dtype=np.float32,
                                mode='r',
                                shape=tuple(metadata["ablated_neurons_shape"]))

    file_list = metadata.get("file_list", [])
    k = metadata.get("k", 25)

    return {
        "attn_maps": attn_maps,
        "baseline_attn_maps": baseline_attn_maps,
        "neuron_activations": neuron_activations,
        "baseline_neuron_activations": baseline_neuron_activations,
        "ablated_neurons": ablated_neurons,
        "metadata": metadata,
        "file_list": file_list,
        "k": k,
        "num_layers": metadata.get("num_layers", 12),
        "num_images": metadata.get("num_images", 100),
        "model_name": metadata.get("model_name", "ViT-B-16"),
        "pretrained": metadata.get("pretrained", "openai")
    }


def find_register_neurons_cuda(model, preprocess, prs_group, register_norm_threshold=30, highest_layer=-1, device="cuda:0", processed_image_cnt=500):
    num_layers = len(model.visual.transformer.resblocks)
    highest_layer = num_layers - 1 if highest_layer == -1 else highest_layer
    num_neurons = model.visual.transformer.resblocks[0].mlp.state_dict()["c_proj.weight"].shape[1]
    random_images = load_images(preprocess, count=processed_image_cnt)
    neuron_scores = torch.zeros((len(random_images), num_layers, num_neurons), device=device)
    # NOTE: alignment_scores is currently never populated, so the second return
    # value is a ranking over all-zero scores.
    alignment_scores = torch.zeros((len(random_images), num_layers, num_neurons), device=device)
    image_count = 0

    for i in tqdm(range(len(random_images)), desc="Processing random images"):
        image = random_images[i].unsqueeze(0).to(device)
        prs_group.reinit()

        with torch.inference_mode():
            model.encode_image(
                image, attn_method="head", normalize=False
            )
        prs_group.finalize()

        baseline_neuron_acts = prs_group.post_gelu_outputs().to(device)
        baseline_resblock_outputs = prs_group.resblock_outputs().to(device)

        # Register tokens are identified by an unusually large output norm in
        # the final residual block.
        norm_map = torch.norm(baseline_resblock_outputs[-1], dim=1)
        register_locations = torch.where(norm_map > register_norm_threshold)[0]

        if len(register_locations) == 0:
            continue

        image_count += 1

        for layer in range(num_layers):
            act_layer = torch.abs(baseline_neuron_acts[layer])

            # Keep only neurons that are near zero on at least half the tokens.
            sparse_neurons = torch.sum(act_layer < 0.5, dim=0) >= 0.5 * act_layer.shape[0]

            if not torch.any(sparse_neurons):
                continue

            # Mean absolute activation of each neuron at the register locations,
            # masked to the sparse neurons.
            register_values = act_layer[register_locations]
            neuron_means = register_values.mean(dim=0)
            neuron_means = neuron_means * sparse_neurons.float()

            neuron_scores[i, layer] = neuron_means

    # Average only over images that actually contained register tokens; rows for
    # skipped images are all zero, so summing every row and dividing by
    # image_count gives that mean.
    mean_neuron_scores = neuron_scores.sum(dim=0) / max(image_count, 1)
    mean_alignment_scores = alignment_scores.sum(dim=0) / max(image_count, 1)

    flattened_scores = mean_neuron_scores.flatten()
    sorted_values, sorted_indices = torch.sort(flattened_scores, descending=True)

    flattened_alignment = mean_alignment_scores.flatten()
    sorted_alignment_values, sorted_alignment_indices = torch.sort(flattened_alignment, descending=True)

    # Unflatten the sorted indices back into (layer, neuron) pairs.
    top_indices = [(idx.item() // num_neurons, idx.item() % num_neurons) for idx in sorted_indices]
    top_alignment_indices = [(idx.item() // num_neurons, idx.item() % num_neurons) for idx in sorted_alignment_indices]

    register_norms = [
        (layer, neuron, sorted_values[i].item())
        for i, (layer, neuron) in enumerate(top_indices)
        if layer <= highest_layer
    ]

    best_alignment_scores = [
        (layer, neuron, sorted_alignment_values[i].item())
        for i, (layer, neuron) in enumerate(top_alignment_indices)
        if layer <= highest_layer
    ]

    return register_norms, best_alignment_scores


def find_register_neurons(model, preprocess, prs_group, register_norm_threshold=30, highest_layer=-1, device="cuda:0", processed_image_cnt=500):
    num_layers = len(model.visual.transformer.resblocks)
    highest_layer = num_layers - 1 if highest_layer == -1 else highest_layer
    num_neurons = model.visual.transformer.resblocks[0].mlp.state_dict()["c_proj.weight"].shape[1]

    random_images = load_images(preprocess, count=processed_image_cnt)
    neuron_scores = torch.zeros((len(random_images), num_layers, num_neurons))
    for i in tqdm(range(len(random_images)), desc="Processing random images"):
        image = random_images[i].unsqueeze(0).to(device)

        prs_group.reinit()
        with torch.no_grad():
            model.encode_image(
                image, attn_method="head", normalize=False
            )
        prs_group.finalize()

        baseline_neuron_acts = prs_group.post_gelu_outputs().cpu().numpy()
        baseline_resblock_outputs = prs_group.resblock_outputs().cpu().numpy()

        # Register tokens are identified by an unusually large output norm in
        # the final residual block.
        norms = np.linalg.norm(baseline_resblock_outputs[-1], axis=1)
        register_locations = np.where(norms > register_norm_threshold)[0]

        # Skip images without register tokens; averaging over an empty slice
        # would otherwise produce NaNs.
        if len(register_locations) == 0:
            continue

        for layer in range(num_layers):
            for neuron in range(num_neurons):
                # neuron_map is a view, so zeroing it mutates baseline_neuron_acts;
                # each (layer, neuron) slice is only visited once, so this is safe.
                neuron_map = baseline_neuron_acts[layer, :, neuron]
                mask = np.zeros_like(neuron_map, dtype=bool)
                mask[register_locations] = True
                neuron_map[~mask] = 0
                # Skip neurons with negative activations at register locations.
                if np.any(neuron_map < 0):
                    continue

                neuron_scores[i, layer, neuron] = float(neuron_map[register_locations].mean())
    mean_neuron_scores = neuron_scores.mean(dim=0)

    flattened_scores = mean_neuron_scores.flatten()
    sorted_values, sorted_indices = torch.sort(flattened_scores, descending=True)

    # Unflatten the sorted indices back into (layer, neuron) pairs.
    top_indices = [(idx.item() // num_neurons, idx.item() % num_neurons) for idx in sorted_indices]

    return [(layer, neuron, sorted_values[i].item()) for i, (layer, neuron) in enumerate(top_indices) if layer <= highest_layer]


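# Example usage (a minimal sketch; assumes `prs_group` is a hook group exposing
# reinit/finalize/post_gelu_outputs/resblock_outputs, as relied on above):
#
#   register_neurons = find_register_neurons(model, preprocess, prs_group,
#                                            register_norm_threshold=30,
#                                            processed_image_cnt=100)
#   print(register_neurons[:5])  # (layer, neuron, mean register activation) triples

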
def plot_attn_maps(attn_maps, image_idx):
    from mpl_toolkits.axes_grid1 import make_axes_locatable

    num_layers, num_heads, patch_height, patch_width = attn_maps.shape
    print(f"Shape of image_shifts: {attn_maps.shape}")

    fig, axes = plt.subplots(num_layers, num_heads, figsize=(2*num_heads, 2*num_layers))
    fig.suptitle(f'Attention Shift Maps for Image #{image_idx}', fontsize=16)

    for layer in range(num_layers):
        # Share a color scale across all heads within a layer.
        layer_vmin = attn_maps[layer].min().item()
        layer_vmax = attn_maps[layer].max().item()

        for head in range(num_heads):
            # plt.subplots squeezes singleton dimensions, so index accordingly.
            if num_layers == 1 and num_heads == 1:
                ax = axes
            elif num_layers == 1:
                ax = axes[head]
            elif num_heads == 1:
                ax = axes[layer]
            else:
                ax = axes[layer, head]

            im = ax.imshow(attn_maps[layer, head], cmap='viridis', vmin=layer_vmin, vmax=layer_vmax)

            ax.set_xticks([])
            ax.set_yticks([])

            if head == 0:
                ax.set_ylabel(f'Layer {layer}')
            if layer == num_layers - 1:
                ax.set_xlabel(f'Head {head}')

            # Attach one colorbar per layer, on the last head's axes.
            if head == num_heads - 1:
                divider = make_axes_locatable(ax)
                cax = divider.append_axes("right", size="5%", pad=0.05)
                plt.colorbar(im, cax=cax)

    plt.tight_layout()
    return plt


def calculate_iou(output, target):
    # For binary masks, the intersection counts pixels that are 1 in both.
    intersection = output * (output == target)
    area_inter = intersection.sum().item()
    area_pred = output.sum().item()
    area_target = target.sum().item()
    union = area_pred + area_target - area_inter
    iou = area_inter / union if union > 0 else 0.0
    return area_inter, union, iou


def calculate_pixel_accuracy(output, target):
    # Accuracy measured over target-positive pixels (i.e., recall for binary masks).
    correct = output * (output == target)
    correct = correct.sum().item()
    total = target.sum().item()
    return correct, total, correct / total if total > 0 else 0.0


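# End-to-end demo (a minimal sketch; assumes a CUDA GPU and that the OpenAI
# ViT-B/16 checkpoint is available; kept under __main__ so importing this module
# has no side effects beyond the chdir above):
if __name__ == "__main__":
    set_cuda_visible_device()
    bundle = get_model(model_name="ViT-B/16", pretrained="openai", device="cuda:0")
    potentials = calc_neuron_potentials(bundle["model"])
    top_asns = calc_top_asns(potentials, top_k=10, per="layer", layers_away=1)
    print(f"Top attention-shifting neurons per layer: {top_asns}")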