model.py
CHANGED
@@ -1,31 +1,7 @@
-###################################################################################################
-###################################################################################################
-###################################################################################################
-
-import collections
-import logging
-
-import json
-import math
-import os
-import re
-from collections import OrderedDict
-from functools import partial
-from typing import List, Optional, Tuple, Union
-
-import torch
-import torch.nn as nn
-
-
-
-########################################################
-########################################################
-########################################################
-########################################################
-
-
 from typing import Callable, Optional, Tuple
+
 import copy
+import json
 import math
 import multiprocessing
 import os
@@ -34,7 +10,6 @@ import torch
 import torch.nn as nn
 import transformers
 
-
 class ContextualModelConfig(transformers.configuration_utils.PretrainedConfig):
     """We create a dummy configuration class that will just set properties
     based on whatever kwargs we pass in.
@@ -54,14 +29,13 @@ class ContextualModelConfig(transformers.configuration_utils.PretrainedConfig):
                 continue
         super().__init__()
 
-
 def load_embedder_and_tokenizer(name: str) -> Tuple[
         transformers.PreTrainedModel,
         transformers.PreTrainedTokenizer
 ]:
-
+    assert name is not None, "name must be provided to load_embedder_and_tokenizer"
     if name.startswith("nomic") or (name == "bert-base-uncased"):
-        model =
+        model = transformers.AutoModelForMaskedLM.from_pretrained(name, trust_remote_code=True).bert
         tokenizer = transformers.AutoTokenizer.from_pretrained(name)
     elif name in ["gtr-base", "gtr_base"]:
         model = transformers.AutoModel.from_pretrained(
@@ -106,8 +80,6 @@ def load_embedder_and_tokenizer(name: str) -> Tuple[
     # from optimum.bettertransformer import BetterTransformer
     # model = BetterTransformer.transform(model)
     return model, tokenizer
-
-
 def get_world_size() -> int:
     try:
         return torch.distributed.get_world_size()
@@ -318,7 +290,7 @@ def maxsim(
         sub_x = slice_tensor_rows(X, start, end)
         if debug_mem_usage: print(f"[maxsim] step {i} cuda mem free/total = {torch.cuda.mem_get_info()}")
         if debug_mem_usage: print("[maxsim] sub_x.shape:", sub_x.shape, "//", "y.shape:", y.shape)
-        sub_sim = sub_x @ y # TODO –
+        sub_sim = sub_x @ y # TODO – Implement sparse max here to save mem!
         sub_sim = sub_sim
         if maximize:
             sub_max_sim_v, sub_max_sim_i = sub_sim.to_dense().max(dim=-1)
@@ -471,7 +443,6 @@ def disable_causality(model: torch.nn.Module):
         f"Set is_causal=False in {disabled_modules} modules from model type {type(model)}"
     )
 
-
 class ContextualModelMixin(nn.Module):
     @property
     def num_corpus_tokens(self) -> int:
@@ -511,9 +482,6 @@ class ContextualModelMixin(nn.Module):
             # Auto-expand for a batch.
             dataset_embeddings = dataset_embeddings[None, :, :] # (b, d) -> (1, b, d)
         dataset_embeddings = dataset_embeddings.to(input_ids.device)
-
-        if len(dataset_embeddings.shape) < 3:
-            raise ValueError(f"dataset_embeddings must have at least 3 dimensions, got {dataset_embeddings.shape}")
 
         batch_size = input_ids.shape[0]
         if (self.transductive_tokens_per_document > 1):
@@ -532,11 +500,9 @@ class ContextualModelMixin(nn.Module):
                 dataset_embeddings = dataset_embeddings[R].reshape((batch_size, self.num_corpus_tokens, self.hidden_size))
             else:
                 dataset_embeddings = dataset_embeddings.reshape((1, self.num_corpus_tokens, self.hidden_size))
+            # print("reshaped to dataset_embeddings.shape =", dataset_embeddings.shape)
 
-
-        if dataset_embeddings.shape[1] < self.num_corpus_tokens:
-            raise ValueError(f"dataset_embeddings must have at least {self.num_corpus_tokens} tokens, got {dataset_embeddings.shape[1]}")
-        elif dataset_embeddings.shape[1] > self.num_corpus_tokens:
+        if dataset_embeddings.shape[1] > self.num_corpus_tokens:
             # If too many dataset embeddings are passed in, just take the first N until
             # we have the proper number.
             dataset_embeddings = dataset_embeddings[:, :self.num_corpus_tokens, :]
@@ -558,6 +524,8 @@ class ContextualModelMixin(nn.Module):
             null_embeddings = self.sequence_dropout_null_embedding[None, None].expand(batch_size, corpus_size, -1)
             dataset_embeddings = null_embeddings
 
+        # print(f"[ContextualModelMixin] dataset_embeddings.shape = {dataset_embeddings.shape}")
+
         # backbone_max_seq_length = self.backbone.config.max_trained_positions
         # assert batch_size + (2 * self.n_soft_prompt + corpus_size) <= backbone_max_seq_length, "too many hard negatives for backbone model"
         soft_prompt = torch.ones((1, self.hidden_size), device=dataset_embeddings.device, dtype=dataset_embeddings.dtype)
@@ -630,8 +598,15 @@ class BiEncoder(transformers.PreTrainedModel):
                 [d1, d2, d3, hn1_1, hn1_2, hn2_1, hn2_2, hn3_1, hn3_2]
             for a corpus with three documents and two hard negatives per document
         """
+        # del dataset_input_ids
+        # del dataset_attention_mask
         del token_type_ids
 
+        # from cde.lib.dist import get_rank
+        # tokenizer = transformers.AutoTokenizer.from_pretrained("bert-base-uncased")
+        # if get_rank() == 0:
+        #     breakpoint()
+        # torch.distributed.barrier()
         outputs = (
             self.embedder(
                 input_ids=input_ids,
@@ -801,7 +776,6 @@ class DatasetConditionedAutoregressive(transformers.PreTrainedModel, ContextualModelMixin):
         return output
 
 
-
 class DatasetConditionedBiencoder(transformers.PreTrainedModel, ContextualModelMixin):
     def __init__(
             self,
@@ -812,14 +786,12 @@ class DatasetConditionedBiencoder(transformers.PreTrainedModel, ContextualModelMixin):
         self.backbone = dataset_backbone
         self.hidden_size = self.backbone.config.hidden_size
         self.hidden_size = dataset_backbone.config.hidden_size
+        # self.input_ln = torch.nn.LayerNorm(
+        #     self.hidden_size,
+        #     eps=self.backbone.config.layer_norm_epsilon
+        # )
         self.contextual_init()
         self._shift_rotary_embedding()
-
-        self.pool_ignore_contextual_tokens = vars(self.config).get("pool_ignore_contextual_tokens", True)
-        self.pool_ignore_instruction_tokens = vars(self.config).get("pool_ignore_instruction_tokens", False)
-
-        tokenizer = transformers.AutoTokenizer.from_pretrained(self.config.embedder)
-        self.pool_instruction_end_id = tokenizer.encode(": ", add_special_tokens=False)[0] # Hardcoded for colon-ending prefixes.
 
     @property
     def num_corpus_tokens(self) -> int:
@@ -848,55 +820,48 @@ class DatasetConditionedBiencoder(transformers.PreTrainedModel, ContextualModelMixin):
             output_hidden_states: bool = False,
             null_dataset_embedding: bool = False,
         ) -> torch.Tensor:
+        # print(f"[DatasetConditionedBiencoder - 0] input_ids.shape => {input_ids.shape} // dataset_embeddings.shape =", dataset_embeddings.shape)
         soft_prompt = self._prepare_dataset_embeddings(
             input_ids=input_ids,
             dataset_embeddings=dataset_embeddings,
             null_dataset_embedding=null_dataset_embedding,
         )
+        # print(f"[DatasetConditionedBiencoder - 1] soft_prompt.shape => {soft_prompt.shape}")
         backbone_attention_mask = torch.ones(
             soft_prompt.shape[0:2],
             dtype=torch.long,
             device=soft_prompt.device,
         )
         inputs_embeds = self.backbone.embeddings(input_ids) # (b, s) -> (b, s, d)
+        # print("[2] inputs_embeds.shape =", inputs_embeds.shape)
         inputs_embeds = torch.cat((soft_prompt, inputs_embeds), dim=1) # (v, 4+b+s, d)
-
+        # print("[3.a] inputs_embeds.shape =", inputs_embeds.shape)
+        attention_mask = torch.cat((backbone_attention_mask, attention_mask), dim=1)
+        # print("[3.b] attention_mask.shape =", attention_mask.shape)
         output = self.backbone(
             inputs_embeds=inputs_embeds,
-            attention_mask=
+            attention_mask=attention_mask,
        ) # (1, 4 + b + s, d)
         # trim soft prompt
         output_vectors = output.last_hidden_state
 
         # use only these tokens
         n_soft_prompt_tokens = soft_prompt.shape[1]
+        # print("n_soft_prompt_tokens =", n_soft_prompt_tokens)
 
-
-
-            # This is a bit arcane but relies on the fact that there will be a BOS token after the
-            # instruction, but also there may or may not be a BOS token at the beginning.
-            instruction_end_idx = (
-                (input_ids == self.pool_instruction_end_id) &
-                attention_mask &
-                (torch.arange(input_ids.shape[1], device=input_ids.device)[None, :] > 0)
-            ).int().argmax(1)
-            is_instruction_token_mask = (
-                torch.arange(input_ids.shape[1], device=input_ids.device)[None, :] <= instruction_end_idx[:, None]
-            )
-            # catch edge case where there is no instruction
-            is_instruction_token_mask = is_instruction_token_mask.where(
-                (instruction_end_idx > 0)[:, None], torch.zeros_like(is_instruction_token_mask)
-            )
-            output_attention_mask = torch.cat((backbone_attention_mask, attention_mask & ~is_instruction_token_mask), dim=1)
-        else:
-            output_attention_mask = input_attention_mask
+        output_vectors = output.last_hidden_state[:, n_soft_prompt_tokens:, :]
+        output_attention_mask = attention_mask[:, n_soft_prompt_tokens:]
 
-
-        output_vectors = output_vectors[:, n_soft_prompt_tokens:, :]
-        output_attention_mask = output_attention_mask[:, n_soft_prompt_tokens:]
+        # print("pooling output_vectors.shape =", output_vectors.shape, "and output_attention_mask.shape =", output_attention_mask.shape)
         output_pooled = mean_pool(output_vectors, output_attention_mask)
+
         # average with original vectors
-
+        # TODO: Argparse for pooling strategy.
+        # output_vectors = torch.cat((soft_prompt_pooled, output_pooled), dim=1) # (b, d) + (b, d) -> (b, 2d)
+        # print("output_pooled.shape =", output_pooled.shape)
+        output = self.output_projection(output_pooled) # (b, 2d) -> (b, d)
+
+        # print("returning output.shape =", output.shape)
 
         if output_hidden_states:
             return {
@@ -967,7 +932,7 @@ class ContextualDocumentEmbeddingTransformer(transformers.PreTrainedModel):
        ):
         super().__init__(config=config)
         dataset_backbone, _ = load_embedder_and_tokenizer(
-            vars(config).get("dataset_backbone"
+            vars(config).get("dataset_backbone") or config.embedder
         )
 
         if config.limit_layers:
@@ -1026,6 +991,8 @@ class ContextualDocumentEmbeddingTransformer(transformers.PreTrainedModel):
             output_hidden_states=output_hidden_states,
         )
 
+
+
 def get_model_class(name: str):
     if name in 'transductive':
         return ContextualDocumentEmbeddingTransformer
@@ -1034,4 +1001,4 @@ def get_model_class(name: str):
     elif name == "dataset_prefix_biencoder":
         return DatasetPrefixBiencoder
     else:
-        raise ValueError(f'unknown model cls {name}')
+        raise ValueError(f'unknown model cls {name}')
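The core change in DatasetConditionedBiencoder.forward above is that the contextual soft-prompt positions now get their own all-ones attention mask, which is concatenated onto the token mask, passed to the backbone, and then trimmed off again before mean pooling. The following is a minimal, standalone sketch of that prepend / extend-mask / trim / pool pattern. It is not the repository's code path: the toy TransformerEncoder stands in for self.backbone, mean_pool is a local stand-in mirroring the helper called in model.py, and all shapes are illustrative assumptions.

import torch
import torch.nn as nn

def mean_pool(hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    # (b, s, d), (b, s) -> (b, d): average over positions where the mask is 1.
    mask = attention_mask.unsqueeze(-1).to(hidden_states.dtype)
    return (hidden_states * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)

batch_size, seq_len, num_context_tokens, hidden_size = 2, 16, 4, 32

token_embeds = torch.randn(batch_size, seq_len, hidden_size)             # (b, s, d), stand-in for backbone.embeddings(input_ids)
soft_prompt = torch.randn(batch_size, num_context_tokens, hidden_size)   # stand-in for the prepared dataset embeddings
attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)

# 1. Prepend the soft prompt and give those positions an all-ones mask,
#    mirroring the two torch.cat calls added in the diff.
prompt_mask = torch.ones(batch_size, num_context_tokens, dtype=torch.long)
inputs_embeds = torch.cat((soft_prompt, token_embeds), dim=1)             # (b, c+s, d)
full_mask = torch.cat((prompt_mask, attention_mask), dim=1)               # (b, c+s)

# 2. Run a backbone over the combined sequence (toy encoder, not the real backbone).
encoder = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(d_model=hidden_size, nhead=4, batch_first=True),
    num_layers=1,
)
hidden = encoder(inputs_embeds, src_key_padding_mask=(full_mask == 0))    # (b, c+s, d)

# 3. Trim the soft-prompt positions and mean-pool only the real tokens,
#    as the updated forward does before output_projection.
hidden = hidden[:, num_context_tokens:, :]
mask = full_mask[:, num_context_tokens:]
pooled = mean_pool(hidden, mask)                                          # (b, d)
print(pooled.shape)  # torch.Size([2, 32])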