Upload utils.py with huggingface_hub
utils.py
ADDED
@@ -0,0 +1,37 @@
import torch

def exists(val):
    return val is not None

# for controlling freezing during training of flamingo

def set_module_requires_grad_(module, requires_grad):
    for param in module.parameters():
        param.requires_grad = requires_grad

def freeze_all_layers_(module):
    set_module_requires_grad_(module, False)

def unfreeze_all_layers_(module):
    set_module_requires_grad_(module, True)

def freeze_model_and_make_eval_(model):
    model.eval()
    freeze_all_layers_(model)

# Builds an additive attention-window bias of shape
# (bsz, 1, tgt_len, tgt_len + past_key_values_length): keys at least
# `att_wd_size` positions behind a query are set to the dtype minimum
# (masked out); everything else, including future positions, stays 0.
def _make_att_wd_mask(
    input_ids_shape: torch.Size,
    dtype: torch.dtype, device: torch.device,
    past_key_values_length: int = 0,
    att_wd_size: int = 0,
):
    bsz, tgt_len = input_ids_shape
    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
    mask_cond = torch.arange(mask.size(-1), device=device)
    mask.masked_fill_(
        mask_cond > (mask_cond - att_wd_size).view(mask.size(-1), 1), 0)
    mask = mask.to(dtype)

    if past_key_values_length > 0:
        # Past (cached) positions are always visible: prepend zero bias columns.
        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
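
For reference, a minimal usage sketch of the helpers above (assuming utils.py is importable; the tiny Linear model and the window size of 3 are arbitrary stand-ins, not part of the upload):

import torch
from utils import freeze_model_and_make_eval_, _make_att_wd_mask

# Freeze a stand-in model: gradients off, eval mode on.
model = torch.nn.Linear(4, 4)
freeze_model_and_make_eval_(model)
assert not model.training
assert all(not p.requires_grad for p in model.parameters())

# Window-bias mask for one sequence of length 5 with a window of 3.
mask = _make_att_wd_mask(torch.Size([1, 5]), torch.float32, torch.device("cpu"), att_wd_size=3)
print(mask.shape)  # torch.Size([1, 1, 5, 5])

Note that row i of the mask is 0 (visible) for keys j > i - att_wd_size and dtype-min (masked) for j <= i - att_wd_size, so on its own it only blocks the far past; presumably it is added to a separate causal mask downstream to yield sliding-window attention.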