"""Standalone Sentence-Debias wrapper – autogenerated.""" |
|
|
import os, torch, transformers |
|
|
from functools import partial |
|
|
|
|
|
def _debias_hook(b_dir, module, inputs, output):
    """Forward hook: remove the component along the bias direction from the encoder output."""
    # The encoder returns either a ModelOutput or a plain tuple; the hidden
    # states are the first element in both cases.
    x = output.last_hidden_state if hasattr(output, "last_hidden_state") else output[0]
    b = b_dir.to(x.device)
    # Orthogonal projection removal applied to every token representation:
    # x_debiased = x - ((x . b) / (b . b)) * b
    proj = torch.matmul(x, b) / torch.dot(b, b)
    debiased = x - proj.unsqueeze(-1) * b
    if hasattr(output, "last_hidden_state"):
        output.last_hidden_state = debiased
        return output
    return (debiased,) + output[1:]


class SentenceDebiasBertForMaskedLM(transformers.BertForMaskedLM):
    """bert-base-uncased with a gender Sentence-Debias projection."""

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
        # Pop the custom kwarg before delegating, so it is not forwarded to the
        # standard Hugging Face loading machinery.
        bias_path = kwargs.pop("bias_direction_path", None)
        model = super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
        if bias_path is None:
            bias_path = os.path.join(pretrained_model_name_or_path, "bias_direction_gender.pt")
        bias_vec = torch.load(bias_path, map_location="cpu")
        # Debias the encoder output before it reaches the MLM head; the hook
        # moves the direction to the hidden states' device on each call.
        model.bert.register_forward_hook(partial(_debias_hook, bias_vec))
        # Keep the direction in the state dict so it is saved alongside the weights.
        model.register_buffer("bias_direction", bias_vec)
        return model
|
|
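# --- Usage sketch (illustrative, not part of the generated wrapper) ----------
# A minimal example of loading the wrapped model and running a masked-LM
# forward pass. It assumes a local checkpoint directory that contains the
# usual BERT files plus `bias_direction_gender.pt`; the directory name below
# is hypothetical.
if __name__ == "__main__":
    from transformers import BertTokenizer

    checkpoint = "path/to/sentence-debias-bert-base-uncased"  # hypothetical path
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    model = SentenceDebiasBertForMaskedLM.from_pretrained(checkpoint)
    model.eval()

    text = "The nurse said that [MASK] would be back soon."
    inputs = tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits

    # Report the top prediction for the masked position.
    mask_pos = (inputs["input_ids"][0] == tokenizer.mask_token_id).nonzero(as_tuple=True)[0]
    top_id = logits[0, mask_pos].argmax(dim=-1)
    print(tokenizer.convert_ids_to_tokens(top_id.tolist()))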