# xprovence-reranker-bgem3-v1 / modeling_xprovence_hf.py
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torch.nn import CrossEntropyLoss, MSELoss
from torch.utils.data import Dataset
from transformers import XLMRobertaPreTrainedModel, XLMRobertaModel, PretrainedConfig, AutoTokenizer
from transformers.modeling_outputs import ModelOutput
from dataclasses import dataclass
from typing import Optional, Union, Tuple, List
import warnings
import numpy as np
from tqdm import tqdm
import string
import spacy
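# Multilingual sentence segmenter used to split contexts into sentences. The pipeline is
# assumed to be installed separately, e.g. via `python -m spacy download xx_sent_ud_sm`.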
nlp = spacy.load("xx_sent_ud_sm")
@dataclass
class RankingCompressionOutput(ModelOutput):
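    """Output type for XProvence: passage-level ranking scores plus token-level compression logits."""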
loss: Optional[torch.FloatTensor] = None
compression_loss: Optional[torch.FloatTensor] = None
ranking_loss: Optional[torch.FloatTensor] = None
compression_logits: torch.FloatTensor = None
ranking_scores: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
class XProvenceConfig(PretrainedConfig):
model_type = "XProvence"
def __init__(self, **kwargs):
super().__init__(**kwargs)
class XProvence(XLMRobertaPreTrainedModel):
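    """XLM-RoBERTa-based model with two heads: a sequence-level reranking head and a token-level context-compression head."""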
config_class = XProvenceConfig
def __init__(self, config):
super().__init__(config)
num_labels = getattr(config, "num_labels", 2)
self.num_labels = num_labels
self.roberta = XLMRobertaModel(config)
output_dim = config.hidden_size
### RANKING LAYER
self.classifier = nn.Linear(output_dim, num_labels)
drop_out = getattr(config, "cls_dropout", None)
drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out
self.dropout = nn.Dropout(drop_out)
### COMPRESSION LAYER: another head (initialized randomly)
token_dropout = drop_out
self.token_dropout = nn.Dropout(token_dropout)
self.token_classifier = nn.Linear(
config.hidden_size, 2
) # => hard coded number of labels
self.tokenizer = AutoTokenizer.from_pretrained(config._name_or_path)
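        # XLM-R position ids are offset past the padding index, so a few positions are
        # reserved below (the -4 is a conservative margin for special tokens).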
self.max_len = config.max_position_embeddings - 4
# Initialize weights and apply final processing
self.post_init()
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
ranking_labels: Optional[torch.LongTensor] = None,
loss_weight: Optional[float] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.Tensor], RankingCompressionOutput]:
"""simplified forward"""
outputs = self.roberta(
input_ids,
token_type_ids=token_type_ids,
attention_mask=attention_mask,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
encoder_layer = outputs[0]
# pooled_output = self.pooler(encoder_layer)
pooled_output = outputs['pooler_output']
pooled_output = self.dropout(pooled_output)
ranking_logits = self.classifier(pooled_output)
compression_logits = self.token_classifier(self.token_dropout(encoder_layer))
        ranking_scores = ranking_logits[:, 0].squeeze()  # use the first logit column as the relevance score
compression_loss = None
ranking_loss = None
if labels is not None:
# move labels to correct device to enable model parallelism
labels = labels.to(compression_logits.device)
loss_fct = CrossEntropyLoss()
compression_loss = loss_fct(compression_logits.view(-1, 2), labels.view(-1))
if ranking_labels is not None:
# here ranking labels are scores (from a teacher) we aim to directly distil (pointwise MSE)
ranking_labels = ranking_labels.to(ranking_logits.device)
loss_fct = MSELoss()
ranking_loss = loss_fct(ranking_scores, ranking_labels.squeeze())
loss = None
if (labels is not None) and (ranking_labels is not None):
            w = loss_weight if loss_weight is not None else 1
loss = compression_loss + w * ranking_loss
elif labels is not None:
loss = compression_loss
elif ranking_labels is not None:
loss = ranking_loss
return RankingCompressionOutput(
loss=loss,
compression_loss=compression_loss,
ranking_loss=ranking_loss,
compression_logits=compression_logits,
ranking_scores=ranking_scores,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def process(
self,
question: Union[List[str], str],
context: Union[List[List[str]], str],
title: Optional[Union[List[List[str]], str]] = "first_sentence",
batch_size=32,
threshold=0.3,
always_select_title=False,
reorder=False,
top_k=5,
enable_warnings=True,
):
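        """
        Prune (and optionally rerank) contexts with respect to the given question(s).

        Args:
            question: a single query string or a list of query strings.
            context: a single context string or, per query, a list of context strings.
            title: "first_sentence" (treat the first sentence of each context as its title),
                None, or a list of title lists aligned with `context`.
            batch_size: batch size used for inference.
            threshold: token/sentence selection threshold in [0, 1].
            always_select_title: keep the title sentence whenever anything else is selected.
            reorder: if True, sort each query's contexts by reranking score and keep `top_k`.
            top_k: number of contexts kept per query when `reorder=True`.
            enable_warnings: emit warnings about unusual inputs.

        Returns:
            dict with "pruned_context", "reranking_score", and "compression_rate"
            (percentage of characters removed), mirroring the nesting of the inputs.
        """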
# convert input format into queries of type List[str] and contexts/titles of type List[List[str]]
if type(question) == str:
queries = [question]
else: # list of strs
queries = question
if type(context) == str:
contexts = [[context]]
else:
contexts = context
if type(title) == str and title != "first_sentence":
titles = [[title]]
else:
titles = title
        assert (
            titles == "first_sentence"
            or titles is None
            or (type(titles) == list and len(titles) == len(queries))
        ), "Variable 'titles' must be 'first_sentence', None, or a list of title lists with the same length as 'queries'"
        if type(titles) == list:
            assert all(
                [
                    len(titles_item) == len(contexts_item)
                    for titles_item, contexts_item in zip(titles, contexts)
                ]
            ), "Each list in 'titles' must have the same length as the corresponding list in 'context'"
assert len(queries) == len(
contexts
), "Lists 'queries' and 'contexts' must have same lengths"
dataset = TestDataset(
queries=queries,
contexts=contexts,
titles=titles,
tokenizer=self.tokenizer,
max_len=self.max_len,
enable_warnings=enable_warnings,
)
selected_contexts = [
[{0: contexts[i][j]} for j in range(len(contexts[i]))]
for i in range(len(queries))
]
reranking_scores = [
[None for j in range(len(contexts[i]))] for i in range(len(queries))
]
compressions = [
[0 for j in range(len(contexts[i]))] for i in range(len(queries))
]
with torch.no_grad():
for batch_start in tqdm(
range(0, len(dataset), batch_size), desc="Pruning contexts..."
):
qis = dataset.qis[batch_start : batch_start + batch_size]
cis = dataset.cis[batch_start : batch_start + batch_size]
sis = dataset.sis[batch_start : batch_start + batch_size]
sent_coords = dataset.sent_coords[
batch_start : batch_start + batch_size
]
ids_list = dataset.ids[batch_start : batch_start + batch_size]
ids = pad_sequence(
ids_list, batch_first=True, padding_value=dataset.pad_idx
).to(self.device)
mask = (ids != dataset.pad_idx).to(self.device)
outputs = self.forward(ids, mask)
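                # softmax over the two token classes; column 1 is the probability of keeping
                # a token, which is thresholded here and later rounded to whole sentences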
scores = F.softmax(outputs["compression_logits"].cpu(), dim=-1)[:, :, 1]
token_preds = scores > threshold
                reranking_scrs = outputs["ranking_scores"].cpu().numpy()  # one relevance score per batch item
if len(reranking_scrs.shape) == 0:
reranking_scrs = reranking_scrs[None]
for (
ids_list_,
token_preds_,
rerank_score,
qi,
ci,
si,
sent_coords_,
) in zip(
ids_list, token_preds, reranking_scrs, qis, cis, sis, sent_coords
):
selected_mask = sentence_rounding(
token_preds_.cpu().numpy(),
np.array(sent_coords_),
threshold=threshold,
always_select_title=always_select_title
and si == 0
and titles != None,
)
assert len(selected_mask) == len(token_preds_)
selected_contexts[qi][ci][si] = ids_list_[
selected_mask[: len(ids_list_)]
]
if si == 0:
reranking_scores[qi][ci] = rerank_score
for i in range(len(queries)):
for j in range(len(contexts[i])):
if type(selected_contexts[i][j][0]) != str:
toks = torch.cat(
[
ids_
for _, ids_ in sorted(
selected_contexts[i][j].items(), key=lambda x: x[0]
)
]
)
selected_contexts[i][j] = self.tokenizer.decode(
toks,
skip_special_tokens=True,
clean_up_tokenization_spaces=False,
)
else:
selected_contexts[i][j] = selected_contexts[i][j][0]
len_original = len(contexts[i][j])
len_compressed = len(selected_contexts[i][j])
compressions[i][j] = (len_original-len_compressed)/len_original * 100
if reorder:
idxs = np.argsort(reranking_scores[i])[::-1][:top_k]
selected_contexts[i] = [selected_contexts[i][j] for j in idxs]
reranking_scores[i] = [reranking_scores[i][j] for j in idxs]
compressions[i] = [compressions[i][j] for j in idxs]
if type(context) == str:
selected_contexts = selected_contexts[0][0]
reranking_scores = reranking_scores[0][0]
compressions = compressions[0][0]
return {
"pruned_context": selected_contexts,
"reranking_score": reranking_scores,
"compression_rate": compressions,
}
# Some utils functions
def sentence_rounding(predictions, chunks, threshold, always_select_title=True):
"""
predictions: a binary vector containing 1 for tokens which were selected and 0s otherwise
chunks: a list of pairs [start, end] of sentence, i.e. sentence is in coordinates predictions[start:end]
the functions
"""
cumulative_sum = np.cumsum(predictions)
chunk_sums = cumulative_sum[chunks[:, 1] - 1] - np.where(
chunks[:, 0] > 0, cumulative_sum[chunks[:, 0] - 1], 0
)
chunk_lengths = chunks[:, 1] - chunks[:, 0]
chunk_means = chunk_sums / chunk_lengths
if always_select_title and (chunk_means>threshold).any():
chunk_means[0] = 1
means = np.hstack((np.zeros(1), chunk_means, np.zeros(1)))
repeats = np.hstack(
([chunks[0][0]], chunk_lengths, [predictions.shape[0] - chunks[-1][1]])
)
return np.repeat(means, repeats) > threshold
def normalize(s: str) -> str:
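    """Lowercase, strip punctuation, and collapse whitespace (queries are normalized this way at training time)."""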
def white_space_fix(text):
return " ".join(text.split())
def remove_punc(text):
exclude = set(string.punctuation)
return "".join(ch for ch in text if ch not in exclude)
def lower(text):
return text.lower()
return white_space_fix(remove_punc(lower(s)))
def sent_split_and_tokenize(text, tokenizer, max_len):
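    """Split `text` into sentences with the multilingual spaCy pipeline, tokenize each sentence
    without special tokens, and chunk any sentence longer than `max_len` into `max_len`-sized pieces."""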
    # sentence segmentation via the multilingual spaCy pipeline loaded above
    sents_raw = [sent.text.strip() for sent in nlp(text).sents]
    sents = []
    for j, sent_raw in enumerate(sents_raw):
        tokinput = (" " if j != 0 else "") + sent_raw
tok = tokenizer.encode(tokinput, add_special_tokens=False)
ltok = len(tok)
if ltok == 0:
continue
if ltok <= max_len:
sents.append(tok)
else:
for begin in range(0, ltok, max_len):
sents.append(tok[begin:begin+max_len])
return sents
class TestDataset(Dataset):
def __init__(
self,
queries,
contexts,
tokenizer,
max_len=6000,
titles="first_sentence",
enable_warnings=True,
):
self.tokenizer = tokenizer
self.max_len = max_len
self.pad_idx = self.tokenizer.pad_token_id
self.cls_idx = [self.tokenizer.cls_token_id]
self.sep_idx = [self.tokenizer.sep_token_id]
self.eos = [self.tokenizer.eos_token_id]
        # special-token ids (CLS / SEP / EOS) taken from the model's tokenizer
self.nb_spe_tok = len(self.cls_idx) + len(self.sep_idx)
self.enable_warnings = enable_warnings
self.unusual_query_length = (
self.max_len // 2
) # TODO: change to data-driven value
self.unusual_title_len = self.max_len // 2 # TODO: change to data-driven value
self.create_dataset(queries, contexts, titles)
self.len = len(self.cis)
def create_dataset(self, queries, contexts, titles="first_sentence"):
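        """
        Build flat, parallel lists over (query, context, block) triples:
        `qis`/`cis`/`sis` hold the query, context and block indexes, `ids` the token ids of
        "[CLS] query [SEP] sentences ... [EOS]", and `sent_coords`/`cntx_coords` the token
        spans of each sentence / of the whole context inside that sequence. Contexts that do
        not fit into `max_len` are split into several blocks.
        """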
self.qis = []
self.cis = []
self.sis = []
self.sent_coords = []
self.cntx_coords = []
self.ids = []
if self.enable_warnings:
warnings_dict = {
"zero_len_query": set(),
"too_long_query": set(),
"unusually_long_query": set(),
"unusually_long_title": set(),
"split_context": set(),
}
for i, query in enumerate(queries):
tokenized_query = self.tokenizer.encode(
normalize(query), add_special_tokens=False
)
# normalize query because all training data has normalized queries
query_len = len(tokenized_query)
if query_len == 0:
if self.enable_warnings:
warnings_dict["zero_len_query"].add(i)
continue
elif query_len >= self.max_len - self.nb_spe_tok - 1: # -1 for eos
if self.enable_warnings:
warnings_dict["too_long_query"].add(i)
continue
elif query_len >= self.unusual_query_length:
if self.enable_warnings:
warnings_dict["unusually_long_query"].add(i)
left_0 = len(tokenized_query) + self.nb_spe_tok
tokenized_seq_0 = self.cls_idx + tokenized_query + self.sep_idx
max_len = self.max_len - left_0 - 1
for j, cntx in enumerate(contexts[i]):
title = titles[i][j] if type(titles) == list else titles
tokenized_sents = sent_split_and_tokenize(cntx, self.tokenizer, max_len)
# each (sent + query + special tokens) <= max_len
if title is not None and title != "first_sentence":
tokenized_title = self.tokenizer.encode(
title, add_special_tokens=False
)
ltok = len(tokenized_title)
if ltok == 0:
pass
elif ltok <= max_len:
tokenized_sents = [tokenized_title] + tokenized_sents
else:
if self.enable_warnings and ltok >= self.unusual_title_len:
warnings_dict["unusually_long_title"].add(i)
tokenized_sents = [
tokenized_title[begin : begin + max_len]
for begin in range(0, ltok, max_len)
] + tokenized_sents
tokenized_seq = tokenized_seq_0
left = left_0
sent_coords = []
block = 0
for idx, tokenized_sent in enumerate(tokenized_sents):
l = len(tokenized_sent)
if left + l <= self.max_len - 1:
sent_coords.append([left, left + l])
tokenized_seq = tokenized_seq + tokenized_sent
left += l
else:
if self.enable_warnings:
warnings_dict["split_context"].add(i)
if len(tokenized_seq) > left_0:
tokenized_seq = tokenized_seq + self.eos
self.qis.append(i)
self.cis.append(j)
self.sis.append(block)
self.sent_coords.append(sent_coords)
self.cntx_coords.append(
[sent_coords[0][0], sent_coords[-1][1]]
)
self.ids.append(torch.tensor(tokenized_seq))
tokenized_seq = tokenized_seq_0 + tokenized_sent
sent_coords = [[left_0, left_0 + l]]
left = left_0 + l
block += 1
if len(tokenized_seq) > left_0:
tokenized_seq = tokenized_seq + self.eos
self.qis.append(i)
self.cis.append(j)
self.sis.append(block)
self.sent_coords.append(sent_coords)
self.cntx_coords.append([sent_coords[0][0], sent_coords[-1][1]])
self.ids.append(torch.tensor(tokenized_seq))
if self.enable_warnings:
self.print_warnings(warnings_dict, len(queries))
def __len__(self):
return len(self.ids)
def print_warnings(self, warnings_dict, N):
n = len(warnings_dict["zero_len_query"])
info = " You can suppress Provence warnings by setting enable_warnings=False."
if n > 0:
ex = list(warnings_dict["zero_len_query"])[:10]
warnings.warn(
f"{n} out of {N} queries have zero length, e.g. at indexes {ex}. "
"These examples will be skipped in context pruning, "
"their contexts will be kept as is." + info
)
n = len(warnings_dict["too_long_query"])
if n > 0:
ex = list(warnings_dict["too_long_query"])[:10]
warnings.warn(
f"{n} out of {N} queries are too long for context length {self.max_len}, "
f"e.g. at indexes {ex}. These examples will be skipped in context pruning, "
"their contexts will be kept as is." + info
)
n = len(warnings_dict["unusually_long_query"])
if n > 0:
ex = list(warnings_dict["unusually_long_query"])[:10]
warnings.warn(
f"{n} out of {N} queries are longer than {self.unusual_query_length} tokens, "
f"e.g. at indexes {ex}. These examples will processed as usual in context pruning, "
"but the quality of context pruning could be reduced." + info
)
n = len(warnings_dict["unusually_long_title"])
if n > 0:
ex = list(warnings_dict["unusually_long_title"])[:10]
warnings.warn(
f"{n} out of {N} titles are longer than {self.unusual_title_length} tokens, "
f"e.g. at indexes {ex}. These examples will processed as usual in context pruning, "
"but the quality of context pruning could be reduced." + info
)
n = len(warnings_dict["split_context"])
if n > 0:
ex = list(warnings_dict["split_context"])[:10]
warnings.warn(
f"{n} out of {N} contexts were split into several pieces for context pruning, "
f"due to a limited context length of Provence which is equal to {self.max_len}. "
"This could potentially reduce the quality of context pruning. "
"You could consider checking and reducing lengths of contexts, queries, or titles."
+ info
)
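

if __name__ == "__main__":
    # Minimal usage sketch (not part of the model code). It assumes the checkpoint is
    # published on the Hugging Face Hub with an `auto_map` entry pointing at XProvence;
    # the repository id below is a placeholder and may need to be adjusted.
    from transformers import AutoModel

    model = AutoModel.from_pretrained(
        "youssef101/xprovence-reranker-bgem3-v1",  # placeholder repository id
        trust_remote_code=True,
    )
    out = model.process(
        question="What goes on top of shepherd's pie?",
        context="Shepherd's pie is a meat pie with a crust of mashed potato. "
        "It is traditionally made with minced lamb.",
    )
    print(out["pruned_context"])
    print(out["reranking_score"], out["compression_rate"])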