duzx16 committed
Commit: 4de8efe
Parent(s): 3a99d79

Change mask positions to batch

Files changed: modeling_chatglm.py (+21 -11)

modeling_chatglm.py CHANGED
@@ -689,8 +689,10 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
 
         return attention_mask
 
-    def get_position_ids(self, input_ids, mask_positions, device, gmask=False):
+    def get_position_ids(self, input_ids, mask_positions, device, use_gmasks=None):
         batch_size, seq_length = input_ids.shape
+        if use_gmasks is None:
+            use_gmasks = [False] * batch_size
         context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids]
         if self.position_encoding_2d:
             position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
@@ -704,8 +706,8 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
             position_ids = torch.stack((position_ids, block_position_ids), dim=1)
         else:
             position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
-            if not gmask:
-                for i, context_length in enumerate(context_lengths):
+            for i, context_length in enumerate(context_lengths):
+                if not use_gmasks[i]:
                     position_ids[context_length:] = mask_positions[i]
 
         return position_ids
@@ -939,15 +941,20 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
 
         if position_ids is None:
             MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id
-            mask_token = gMASK if gMASK in input_ids else MASK
-            use_gmask = True if gMASK in input_ids else False
+            seqs = input_ids.tolist()
+
+            mask_positions, use_gmasks = [], []
+            for seq in seqs:
+                mask_token = gMASK if gMASK in seq else MASK
+                use_gmask = mask_token == gMASK
+                mask_positions.append(seq.index(mask_token))
+                use_gmasks.append(use_gmask)
 
-            mask_positions = [seq.tolist().index(mask_token) for seq in input_ids]
             position_ids = self.get_position_ids(
                 input_ids,
                 mask_positions=mask_positions,
                 device=input_ids.device,
-                gmask=use_gmask
+                use_gmasks=use_gmasks
             )
 
         if self.pre_seq_len is not None and attention_mask is not None:
@@ -1106,10 +1113,13 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
     ) -> dict:
         batch_size, seq_length = input_ids.shape
         MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id
-        mask_token = gMASK if gMASK in input_ids else MASK
-        use_gmask = True if gMASK in input_ids else False
         seqs = input_ids.tolist()
-        mask_positions = [seq.index(mask_token) for seq in seqs]
+        mask_positions, use_gmasks = [], []
+        for seq in seqs:
+            mask_token = gMASK if gMASK in seq else MASK
+            use_gmask = mask_token == gMASK
+            mask_positions.append(seq.index(mask_token))
+            use_gmasks.append(use_gmask)
 
         # only last token for input_ids if past is not None
         if past is not None or past_key_values is not None:
@@ -1152,7 +1162,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
             input_ids,
            device=input_ids.device,
            mask_positions=mask_positions,
-           gmask=use_gmask
+           use_gmasks=use_gmasks
        )
 
        return {
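Both call sites in this diff converge on the same per-sequence scan: each row picks whichever mask token it actually contains, records that token's first position, and remembers whether it was gMASK. Below is a minimal, self-contained sketch of that scan; the token ids are assumed stand-ins for config.mask_token_id, config.gmask_token_id and config.bos_token_id, not values taken from this commit.

import torch

# Assumed stand-in ids; the exact values do not matter for the logic.
MASK, gMASK, BOS = 130000, 130001, 130004

# A batch of two prompts: row 0 contains gMASK, row 1 contains MASK.
input_ids = torch.tensor([
    [5, 6, gMASK, BOS, 7, 8],
    [5, MASK, 6, 9, BOS, 7],
])

# The per-sequence loop introduced by this commit, lifted out of the model.
seqs = input_ids.tolist()
mask_positions, use_gmasks = [], []
for seq in seqs:
    mask_token = gMASK if gMASK in seq else MASK   # mask token actually present in this row
    use_gmask = mask_token == gMASK
    mask_positions.append(seq.index(mask_token))   # first occurrence, per row
    use_gmasks.append(use_gmask)

print(mask_positions)  # [2, 1]
print(use_gmasks)      # [True, False]

In the removed code, gMASK in input_ids was tested once against the whole tensor, so a batch mixing MASK and gMASK prompts collapsed to a single mask token and a single gmask flag; the two lists above are what get_position_ids now receives instead.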
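To see what the per-row flags change downstream, here is a free-standing, simplified stand-in for the non-2d branch of get_position_ids after this change. It is not the repository's code: it hard-codes an assumed bos id, drops the position_encoding_2d branch, and indexes the clamped row explicitly, but it illustrates the gating the new use_gmasks argument enables, including the use_gmasks is None default added at the top of the method.

import torch

BOS = 130004  # assumed stand-in for config.bos_token_id

def toy_get_position_ids(input_ids, mask_positions, device, use_gmasks=None):
    # Simplified stand-in, not the model's implementation.
    batch_size, seq_length = input_ids.shape
    if use_gmasks is None:                     # default added by this commit
        use_gmasks = [False] * batch_size
    context_lengths = [seq.tolist().index(BOS) for seq in input_ids]
    position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
    for i, context_length in enumerate(context_lengths):
        if not use_gmasks[i]:                  # each row is gated by its own flag
            position_ids[i, context_length:] = mask_positions[i]
    return position_ids

ids = torch.tensor([
    [5, 130001, 130004, 7, 8],   # row 0 uses gMASK: positions keep counting past bos
    [5, 130000, 130004, 7, 8],   # row 1 uses MASK: positions after bos clamp to the mask position
])
print(toy_get_position_ids(ids, mask_positions=[1, 1], device="cpu", use_gmasks=[True, False]))
# tensor([[0, 1, 2, 3, 4],
#         [0, 1, 1, 1, 1]])

Before this commit the same decision was made once per batch through a single gmask keyword, so rows that use gMASK and rows that use MASK could not be treated differently within one call.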