jxm committed
Commit cba0a1d · verified · 1 Parent(s): 629557b

edit source

Files changed (1):
  model.py  +42 -75
model.py CHANGED
@@ -1,31 +1,7 @@
- ###################################################################################################
- ###################################################################################################
- ###################################################################################################
-
- import collections
- import logging
-
- import json
- import math
- import os
- import re
- from collections import OrderedDict
- from functools import partial
- from typing import List, Optional, Tuple, Union
-
- import torch
- import torch.nn as nn
-
-
-
- ########################################################
- ########################################################
- ########################################################
- ########################################################
-
-
  from typing import Callable, Optional, Tuple
+
  import copy
+ import json
  import math
  import multiprocessing
  import os
@@ -34,7 +10,6 @@ import torch
  import torch.nn as nn
  import transformers

-
  class ContextualModelConfig(transformers.configuration_utils.PretrainedConfig):
      """We create a dummy configuration class that will just set properties
      based on whatever kwargs we pass in.
@@ -54,14 +29,13 @@ class ContextualModelConfig(transformers.configuration_utils.PretrainedConfig):
                  continue
          super().__init__()

-
  def load_embedder_and_tokenizer(name: str) -> Tuple[
      transformers.PreTrainedModel,
      transformers.PreTrainedTokenizer
  ]:
-     print("Loading model:", name)
+     assert name is not None, "name must be provided to load_embedder_and_tokenizer"
      if name.startswith("nomic") or (name == "bert-base-uncased"):
-         model = ContextualNomicBertForPreTraining.from_pretrained(name, trust_remote_code=True).bert
+         model = transformers.AutoModelForMaskedLM.from_pretrained(name, trust_remote_code=True).bert
          tokenizer = transformers.AutoTokenizer.from_pretrained(name)
      elif name in ["gtr-base", "gtr_base"]:
          model = transformers.AutoModel.from_pretrained(
@@ -106,8 +80,6 @@ def load_embedder_and_tokenizer(name: str) -> Tuple[
      # from optimum.bettertransformer import BetterTransformer
      # model = BetterTransformer.transform(model)
      return model, tokenizer
-
-
  def get_world_size() -> int:
      try:
          return torch.distributed.get_world_size()
@@ -318,7 +290,7 @@ def maxsim(
          sub_x = slice_tensor_rows(X, start, end)
          if debug_mem_usage: print(f"[maxsim] step {i} cuda mem free/total = {torch.cuda.mem_get_info()}")
          if debug_mem_usage: print("[maxsim] sub_x.shape:", sub_x.shape, "//", "y.shape:", y.shape)
-         sub_sim = sub_x @ y # TODO – Implement sparse max here to save mem!
+         sub_sim = sub_x @ y # TODO – Implement sparse max here to save mem!
          sub_sim = sub_sim
          if maximize:
              sub_max_sim_v, sub_max_sim_i = sub_sim.to_dense().max(dim=-1)
@@ -471,7 +443,6 @@ def disable_causality(model: torch.nn.Module):
          f"Set is_causal=False in {disabled_modules} modules from model type {type(model)}"
      )

-
  class ContextualModelMixin(nn.Module):
      @property
      def num_corpus_tokens(self) -> int:
@@ -511,9 +482,6 @@ class ContextualModelMixin(nn.Module):
              # Auto-expand for a batch.
              dataset_embeddings = dataset_embeddings[None, :, :] # (b, d) -> (1, b, d)
              dataset_embeddings = dataset_embeddings.to(input_ids.device)
-
-         if len(dataset_embeddings.shape) < 3:
-             raise ValueError(f"dataset_embeddings must have at least 3 dimensions, got {dataset_embeddings.shape}")

          batch_size = input_ids.shape[0]
          if (self.transductive_tokens_per_document > 1):
@@ -532,11 +500,9 @@ class ContextualModelMixin(nn.Module):
                  dataset_embeddings = dataset_embeddings[R].reshape((batch_size, self.num_corpus_tokens, self.hidden_size))
              else:
                  dataset_embeddings = dataset_embeddings.reshape((1, self.num_corpus_tokens, self.hidden_size))
+         # print("reshaped to dataset_embeddings.shape =", dataset_embeddings.shape)

-
-         if dataset_embeddings.shape[1] < self.num_corpus_tokens:
-             raise ValueError(f"dataset_embeddings must have at least {self.num_corpus_tokens} tokens, got {dataset_embeddings.shape[1]}")
-         elif dataset_embeddings.shape[1] > self.num_corpus_tokens:
+         if dataset_embeddings.shape[1] > self.num_corpus_tokens:
              # If too many dataset embeddings are passed in, just take the first N until
              # we have the proper number.
              dataset_embeddings = dataset_embeddings[:, :self.num_corpus_tokens, :]
@@ -558,6 +524,8 @@ class ContextualModelMixin(nn.Module):
              null_embeddings = self.sequence_dropout_null_embedding[None, None].expand(batch_size, corpus_size, -1)
              dataset_embeddings = null_embeddings

+         # print(f"[ContextualModelMixin] dataset_embeddings.shape = {dataset_embeddings.shape}")
+
          # backbone_max_seq_length = self.backbone.config.max_trained_positions
          # assert batch_size + (2 * self.n_soft_prompt + corpus_size) <= backbone_max_seq_length, "too many hard negatives for backbone model"
          soft_prompt = torch.ones((1, self.hidden_size), device=dataset_embeddings.device, dtype=dataset_embeddings.dtype)
@@ -630,8 +598,15 @@ class BiEncoder(transformers.PreTrainedModel):
              [d1, d2, d3, hn1_1, hn1_2, hn2_1, hn2_2, hn3_1, hn3_2]
          for a corpus with three documents and two hard negatives per document
          """
+         # del dataset_input_ids
+         # del dataset_attention_mask
          del token_type_ids

+         # from cde.lib.dist import get_rank
+         # tokenizer = transformers.AutoTokenizer.from_pretrained("bert-base-uncased")
+         # if get_rank() == 0:
+         #     breakpoint()
+         # torch.distributed.barrier()
          outputs = (
              self.embedder(
                  input_ids=input_ids,
@@ -801,7 +776,6 @@ class DatasetConditionedAutoregressive(transformers.PreTrainedModel, ContextualModelMixin):
          return output


-
  class DatasetConditionedBiencoder(transformers.PreTrainedModel, ContextualModelMixin):
      def __init__(
          self,
@@ -812,14 +786,12 @@ class DatasetConditionedBiencoder(transformers.PreTrainedModel, ContextualModelMixin):
          self.backbone = dataset_backbone
          self.hidden_size = self.backbone.config.hidden_size
          self.hidden_size = dataset_backbone.config.hidden_size
+         # self.input_ln = torch.nn.LayerNorm(
+         #     self.hidden_size,
+         #     eps=self.backbone.config.layer_norm_epsilon
+         # )
          self.contextual_init()
          self._shift_rotary_embedding()
-
-         self.pool_ignore_contextual_tokens = vars(self.config).get("pool_ignore_contextual_tokens", True)
-         self.pool_ignore_instruction_tokens = vars(self.config).get("pool_ignore_instruction_tokens", False)
-
-         tokenizer = transformers.AutoTokenizer.from_pretrained(self.config.embedder)
-         self.pool_instruction_end_id = tokenizer.encode(": ", add_special_tokens=False)[0] # Hardcoded for colon-ending prefixes.

      @property
      def num_corpus_tokens(self) -> int:
@@ -848,55 +820,48 @@ class DatasetConditionedBiencoder(transformers.PreTrainedModel, ContextualModelMixin):
          output_hidden_states: bool = False,
          null_dataset_embedding: bool = False,
      ) -> torch.Tensor:
+         # print(f"[DatasetConditionedBiencoder - 0] input_ids.shape => {input_ids.shape} // dataset_embeddings.shape =", dataset_embeddings.shape)
          soft_prompt = self._prepare_dataset_embeddings(
              input_ids=input_ids,
              dataset_embeddings=dataset_embeddings,
              null_dataset_embedding=null_dataset_embedding,
          )
+         # print(f"[DatasetConditionedBiencoder - 1] soft_prompt.shape => {soft_prompt.shape}")
          backbone_attention_mask = torch.ones(
              soft_prompt.shape[0:2],
              dtype=torch.long,
              device=soft_prompt.device,
          )
          inputs_embeds = self.backbone.embeddings(input_ids) # (b, s) -> (b, s, d)
+         # print("[2] inputs_embeds.shape =", inputs_embeds.shape)
          inputs_embeds = torch.cat((soft_prompt, inputs_embeds), dim=1) # (v, 4+b+s, d)
-         input_attention_mask = torch.cat((backbone_attention_mask, attention_mask), dim=1)
+         # print("[3.a] inputs_embeds.shape =", inputs_embeds.shape)
+         attention_mask = torch.cat((backbone_attention_mask, attention_mask), dim=1)
+         # print("[3.b] attention_mask.shape =", attention_mask.shape)
          output = self.backbone(
              inputs_embeds=inputs_embeds,
-             attention_mask=input_attention_mask,
+             attention_mask=attention_mask,
          ) # (1, 4 + b + s, d)
          # trim soft prompt
          output_vectors = output.last_hidden_state

          # use only these tokens
          n_soft_prompt_tokens = soft_prompt.shape[1]
+         # print("n_soft_prompt_tokens =", n_soft_prompt_tokens)

-         if self.pool_ignore_instruction_tokens:
-             # Denote the end of an instruction with an extra BOS token.
-             # This is a bit arcane but relies on the fact that there will be a BOS token after the
-             # instruction, but also there may or may not be a BOS token at the beginning.
-             instruction_end_idx = (
-                 (input_ids == self.pool_instruction_end_id) &
-                 attention_mask &
-                 (torch.arange(input_ids.shape[1], device=input_ids.device)[None, :] > 0)
-             ).int().argmax(1)
-             is_instruction_token_mask = (
-                 torch.arange(input_ids.shape[1], device=input_ids.device)[None, :] <= instruction_end_idx[:, None]
-             )
-             # catch edge case where there is no instruction
-             is_instruction_token_mask = is_instruction_token_mask.where(
-                 (instruction_end_idx > 0)[:, None], torch.zeros_like(is_instruction_token_mask)
-             )
-             output_attention_mask = torch.cat((backbone_attention_mask, attention_mask & ~is_instruction_token_mask), dim=1)
-         else:
-             output_attention_mask = input_attention_mask
+         output_vectors = output.last_hidden_state[:, n_soft_prompt_tokens:, :]
+         output_attention_mask = attention_mask[:, n_soft_prompt_tokens:]

-         if self.pool_ignore_contextual_tokens:
-             output_vectors = output_vectors[:, n_soft_prompt_tokens:, :]
-             output_attention_mask = output_attention_mask[:, n_soft_prompt_tokens:]
+         # print("pooling output_vectors.shape =", output_vectors.shape, "and output_attention_mask.shape =", output_attention_mask.shape)
          output_pooled = mean_pool(output_vectors, output_attention_mask)
+
          # average with original vectors
-         output = self.output_projection(output_pooled) + output_pooled # (b, d) -> (b, d) / with residual connection
+         # TODO: Argparse for pooling strategy.
+         # output_vectors = torch.cat((soft_prompt_pooled, output_pooled), dim=1) # (b, d) + (b, d) -> (b, 2d)
+         # print("output_pooled.shape =", output_pooled.shape)
+         output = self.output_projection(output_pooled) # (b, 2d) -> (b, d)
+
+         # print("returning output.shape =", output.shape)

          if output_hidden_states:
              return {
@@ -967,7 +932,7 @@ class ContextualDocumentEmbeddingTransformer(transformers.PreTrainedModel):
      ):
          super().__init__(config=config)
          dataset_backbone, _ = load_embedder_and_tokenizer(
-             vars(config).get("dataset_backbone", config.embedder)
+             vars(config).get("dataset_backbone") or config.embedder
          )

          if config.limit_layers:
@@ -1026,6 +991,8 @@ class ContextualDocumentEmbeddingTransformer(transformers.PreTrainedModel):
              output_hidden_states=output_hidden_states,
          )

+
+
  def get_model_class(name: str):
      if name in 'transductive':
          return ContextualDocumentEmbeddingTransformer
@@ -1034,4 +1001,4 @@ def get_model_class(name: str):
      elif name == "dataset_prefix_biencoder":
          return DatasetPrefixBiencoder
      else:
-         raise ValueError(f'unknown model cls {name}')
+         raise ValueError(f'unknown model cls {name}')
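
Note on the backbone lookup changed in ContextualDocumentEmbeddingTransformer.__init__: vars(config).get("dataset_backbone", config.embedder) only falls back to the embedder when the key is missing entirely, while vars(config).get("dataset_backbone") or config.embedder also falls back when the key is present but stored as None (or another falsy value). A minimal standalone Python sketch of that difference; the dict and model name below are hypothetical placeholders, not values taken from this repository:

# Stand-in for vars(config); values here are made up for illustration.
config = {"embedder": "some-embedder-name", "dataset_backbone": None}

# dict.get(key, default) returns the stored value because the key exists:
print(config.get("dataset_backbone", config["embedder"]))    # -> None

# `or` falls back on any falsy value, so the embedder name is used instead:
print(config.get("dataset_backbone") or config["embedder"])  # -> some-embedder-name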