# bert-base-uncased-sent-debias-gender2 / modeling_sentence_debias.py
# Author: FilipT
# Add gender-debiased bert-base-uncased (Sentence-Debias, standalone)
# Commit b670ab9 (verified)
"""Standalone Sentence-Debias wrapper – autogenerated."""
import os, torch, transformers
from functools import partial
def _debias_hook(b_dir, module, inputs, output):
    """Forward hook that removes the bias-direction component from encoder output.

    Meant to be bound with ``functools.partial(_debias_hook, b_dir)`` and passed
    to ``register_forward_hook``; the remaining three arguments follow the
    PyTorch forward-hook signature ``(module, inputs, output)``.

    Every hidden vector ``x`` is replaced by ``x - (x . u) u`` where
    ``u = b_dir / ||b_dir||``; this equals ``x - (x . b / b . b) b`` and leaves
    the result orthogonal to ``b_dir``.

    Args:
        b_dir: 1-D bias-direction tensor whose length matches the last
            dimension of the hidden states. It does not need to be unit-norm.
        module: The hooked module (unused).
        inputs: The module's positional inputs (unused).
        output: Either an object exposing ``last_hidden_state`` (e.g. a
            transformers ``ModelOutput``) or a tuple whose first element is the
            hidden-state tensor.

    Returns:
        The same object with ``last_hidden_state`` replaced (attribute case),
        or a new tuple whose first element is the debiased tensor.
    """
    has_attr = hasattr(output, "last_hidden_state")
    x = output.last_hidden_state if has_attr else output[0]
    # Normalise in b_dir's own precision first so b.b cannot overflow when the
    # activations are half precision, then match the activations' device AND
    # dtype -- otherwise matmul raises on a dtype mismatch (e.g. fp16 model
    # with a float32 direction).
    u = (b_dir / b_dir.norm()).to(device=x.device, dtype=x.dtype)
    debiased = x - torch.matmul(x, u).unsqueeze(-1) * u
    if has_attr:
        output.last_hidden_state = debiased
        return output
    return (debiased,) + tuple(output[1:])
class SentenceDebiasBertForMaskedLM(transformers.BertForMaskedLM):
    """bert-base-uncased with a gender Sentence-Debias projection.

    Use :meth:`from_pretrained`. It builds a regular ``BertForMaskedLM`` and
    then registers a forward hook on ``model.bert`` (``_debias_hook``). The
    hook removes the component of the encoder output that lies along a stored
    bias direction. The direction is also kept on the model as the
    ``bias_direction`` buffer.
    """

    # File name of the bias direction, looked up next to the model weights.
    _BIAS_FILENAME = "bias_direction_gender.pt"
    # from_pretrained kwargs that also determine where the bias file is fetched from.
    _HUB_KWARGS = (
        "cache_dir",
        "force_download",
        "proxies",
        "token",
        "revision",
        "local_files_only",
        "subfolder",
    )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
        """Load the model and attach the Sentence-Debias hook.

        Args:
            pretrained_model_name_or_path: Local directory or Hub repo id,
                forwarded to ``BertForMaskedLM.from_pretrained``.
            *args: Forwarded unchanged to the base ``from_pretrained``.
            bias_direction_path: Optional keyword giving an explicit path to
                the saved bias-direction tensor. When omitted,
                ``bias_direction_gender.pt`` is resolved from the same local
                directory or Hub repo as the weights.
            **kwargs: Forwarded to the base ``from_pretrained``.

        Returns:
            The loaded model, with the debiasing hook registered on
            ``model.bert`` and the direction stored in the
            ``bias_direction`` buffer.

        Raises:
            TypeError: If the bias file does not contain a tensor.
            ValueError: If the tensor is not 1-D of size ``config.hidden_size``.
        """
        # 1. Take our own kwarg out BEFORE delegating, so the base loader
        #    never sees a keyword it does not understand.
        bias_path = kwargs.pop("bias_direction_path", None)

        # 2. Load the base model normally.
        model = super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs)

        # 3. Locate and load the bias vector (local dir or Hub repo).
        if bias_path is None:
            bias_path = cls._resolve_bias_path(pretrained_model_name_or_path, kwargs)
        # The file only needs to hold a tensor, so refuse arbitrary pickled
        # objects -- it may have been downloaded from the Hub.
        bias_vec = torch.load(bias_path, map_location="cpu", weights_only=True)
        cls._check_bias_vec(bias_vec, model.config.hidden_size)

        # 4. Register the debiasing hook on the encoder.
        model.bert.register_forward_hook(partial(_debias_hook, bias_vec))
        model.register_buffer("bias_direction", bias_vec)
        return model

    @classmethod
    def _resolve_bias_path(cls, pretrained_model_name_or_path, kwargs):
        """Return a local filesystem path to the bias-direction file."""
        subfolder = kwargs.get("subfolder") or ""
        if os.path.isdir(pretrained_model_name_or_path):
            return os.path.join(pretrained_model_name_or_path, subfolder, cls._BIAS_FILENAME)
        # Not a local directory: treat it as a Hub repo id and download/cache
        # the file with the same location-related options used for the weights.
        from transformers.utils import cached_file

        hub_kwargs = {key: kwargs[key] for key in cls._HUB_KWARGS if key in kwargs}
        return cached_file(str(pretrained_model_name_or_path), cls._BIAS_FILENAME, **hub_kwargs)

    @staticmethod
    def _check_bias_vec(bias_vec, hidden_size):
        """Raise if *bias_vec* cannot serve as a single bias direction."""
        if not isinstance(bias_vec, torch.Tensor):
            raise TypeError(
                f"bias direction file must contain a torch.Tensor, got {type(bias_vec).__name__}"
            )
        if bias_vec.dim() != 1 or bias_vec.numel() != hidden_size:
            raise ValueError(
                f"bias direction must be a 1-D tensor of size {hidden_size}, "
                f"got shape {tuple(bias_vec.shape)}"
            )