import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import matplotlib.pyplot as plt
import numpy as np

base_dir = './'
device = "cuda" if torch.cuda.is_available() else "cpu"

# Compare attention patterns across the three 1B model variants.
for name in ['baseline', 'gate_elementwise', 'gate_headwise']:
    model_name_or_path = f"{base_dir}/1B_{name}"
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
    model = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True).to(device)

    prompt = "Sparse gating mechanism mitigates attention sink."
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    # Forward pass, returning the attention weights of every layer.
    with torch.no_grad():
        outputs = model(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            output_attentions=True
        )

    # Tuple with one tensor per layer, each of shape (batch, heads, seq_len, seq_len).
    attentions = outputs.attentions

    # Average the attention weights over heads to get one (seq_len, seq_len) map per layer.
    def average_heads(attentions):
        averaged = []
        for layer_attn in attentions:
            # (batch, heads, seq_len, seq_len) -> (batch, seq_len, seq_len)
            avg_attn = layer_attn.mean(dim=1).cpu().numpy()
            averaged.append(avg_attn[0])
        return averaged

    averaged_attentions = average_heads(attentions)

    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])

    # Visualize the head-averaged maps for a few selected layers (0-indexed).
    layers_to_visualize = [0, 6, 20, 27]
    fig, axes = plt.subplots(2, 2, figsize=(14, 12))
    axes = axes.flatten()

    for idx, layer_idx in enumerate(layers_to_visualize):
        attn_map = averaged_attentions[layer_idx]

        ax = axes[idx]
        im = ax.imshow(attn_map, cmap="viridis")
        fig.colorbar(im, ax=ax)
        ax.set_title(f"Layer {layer_idx + 1}")

        # Label both axes with the prompt tokens.
        ax.set_xticks(np.arange(len(tokens)))
        ax.set_yticks(np.arange(len(tokens)))
        ax.set_xticklabels(tokens, rotation=90)
        ax.set_yticklabels(tokens)
        ax.tick_params(axis='both', which='both', length=0)

    plt.tight_layout()
    plt.savefig(f"{name}_selected_layer_attention_maps.png")
    plt.show()
    # Close the figure so the next model's plot starts from a clean state.
    plt.close(fig)