momergul commited on
Commit
cbc8d5f
·
verified ·
1 Parent(s): 908e39c

Upload utils.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. utils.py +37 -0
utils.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ def exists(val):
4
+ return val is not None
5
+
6
+ # for controlling freezing during training of flamingo
7
+
8
+ def set_module_requires_grad_(module, requires_grad):
9
+ for param in module.parameters():
10
+ param.requires_grad = requires_grad
11
+
12
+ def freeze_all_layers_(module):
13
+ set_module_requires_grad_(module, False)
14
+
15
+ def unfreeze_all_layers_(module):
16
+ set_module_requires_grad_(module, True)
17
+
18
+ def freeze_model_and_make_eval_(model):
19
+ model.eval()
20
+ freeze_all_layers_(model)
21
+
22
+ def _make_att_wd_mask(
23
+ input_ids_shape: torch.Size,
24
+ dtype: torch.dtype, device: torch.device,
25
+ past_key_values_length: int = 0,
26
+ att_wd_size: int = 0,
27
+ ):
28
+ bsz, tgt_len = input_ids_shape
29
+ mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
30
+ mask_cond = torch.arange(mask.size(-1), device=device)
31
+ mask.masked_fill_(
32
+ mask_cond > (mask_cond - att_wd_size).view(mask.size(-1), 1), 0)
33
+ mask = mask.to(dtype)
34
+
35
+ if past_key_values_length > 0:
36
+ mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
37
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)