import string
import warnings
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import numpy as np
import spacy
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import (
    AutoTokenizer,
    PretrainedConfig,
    XLMRobertaModel,
    XLMRobertaPreTrainedModel,
)
from transformers.modeling_outputs import ModelOutput

# Multilingual sentence segmenter used to split contexts into sentences.
nlp = spacy.load("xx_sent_ud_sm")


@dataclass
class RankingCompressionOutput(ModelOutput):
    loss: Optional[torch.FloatTensor] = None
    compression_loss: Optional[torch.FloatTensor] = None
    ranking_loss: Optional[torch.FloatTensor] = None
    compression_logits: torch.FloatTensor = None
    ranking_scores: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None


class XProvenceConfig(PretrainedConfig):
    model_type = "XProvence"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)


class XProvence(XLMRobertaPreTrainedModel):
    config_class = XProvenceConfig

    def __init__(self, config):
        super().__init__(config)
        num_labels = getattr(config, "num_labels", 2)
        self.num_labels = num_labels
        self.roberta = XLMRobertaModel(config)
        output_dim = config.hidden_size

        ### RANKING LAYER
        self.classifier = nn.Linear(output_dim, num_labels)
        drop_out = getattr(config, "cls_dropout", None)
        drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out
        self.dropout = nn.Dropout(drop_out)

        ### COMPRESSION LAYER: another head (initialized randomly)
        token_dropout = drop_out
        self.token_dropout = nn.Dropout(token_dropout)
        self.token_classifier = nn.Linear(
            config.hidden_size, 2
        )  # => hard-coded number of labels: keep/drop per token

        self.tokenizer = AutoTokenizer.from_pretrained(config._name_or_path)
        self.max_len = config.max_position_embeddings - 4

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        ranking_labels: Optional[torch.LongTensor] = None,
        loss_weight: Optional[float] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], RankingCompressionOutput]:
        """Simplified forward pass with a ranking head and a token-level compression head."""
        outputs = self.roberta(
            input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,  # we rely on named output fields below
        )
        encoder_layer = outputs[0]
        pooled_output = outputs["pooler_output"]
        pooled_output = self.dropout(pooled_output)
        ranking_logits = self.classifier(pooled_output)
        compression_logits = self.token_classifier(self.token_dropout(encoder_layer))
        ranking_scores = ranking_logits[
            :, 0
        ].squeeze()  # select the first logit as the ranking score

        compression_loss = None
        ranking_loss = None
        if labels is not None:
            # move labels to the correct device to enable model parallelism
            labels = labels.to(compression_logits.device)
            loss_fct = CrossEntropyLoss()
            compression_loss = loss_fct(compression_logits.view(-1, 2), labels.view(-1))
        if ranking_labels is not None:
            # ranking labels are scores (from a teacher) we aim to distil directly (pointwise MSE)
            ranking_labels = ranking_labels.to(ranking_logits.device)
            loss_fct = MSELoss()
            ranking_loss = loss_fct(ranking_scores, ranking_labels.squeeze())
        loss = None
        if (labels is not None) and (ranking_labels is not None):
            w = loss_weight if loss_weight is not None else 1.0
            loss = compression_loss + w * ranking_loss
        elif labels is not None:
            loss = compression_loss
        elif ranking_labels is not None:
            loss = ranking_loss

        return RankingCompressionOutput(
            loss=loss,
            compression_loss=compression_loss,
            ranking_loss=ranking_loss,
            compression_logits=compression_logits,
            ranking_scores=ranking_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
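    # Informal sketch of the joint objective computed above (not executed code):
    #
    #     loss = CE(compression_logits, labels) + w * MSE(ranking_scores, ranking_labels)
    #
    # where w = loss_weight (defaulting to 1.0). For instance, loss_weight=0.5
    # with compression_loss=0.8 and ranking_loss=0.2 gives loss = 0.8 + 0.5 * 0.2 = 0.9.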
    def process(
        self,
        question: Union[List[str], str],
        context: Union[List[List[str]], str],
        title: Optional[Union[List[List[str]], str]] = "first_sentence",
        batch_size=32,
        threshold=0.3,
        always_select_title=False,
        reorder=False,
        top_k=5,
        enable_warnings=True,
    ):
        # convert input into queries of type List[str] and contexts/titles of type List[List[str]]
        if isinstance(question, str):
            queries = [question]
        else:  # list of strs
            queries = question
        if isinstance(context, str):
            contexts = [[context]]
        else:
            contexts = context
        if isinstance(title, str) and title != "first_sentence":
            titles = [[title]]
        else:
            titles = title
        assert (
            titles == "first_sentence"
            or titles is None
            or (isinstance(titles, list) and len(titles) == len(queries))
        ), "Variable 'titles' must be 'first_sentence' or a list of strings of the same length as 'queries'"
        if isinstance(titles, list):
            assert all(
                len(titles_item) == len(contexts_item)
                for titles_item, contexts_item in zip(titles, contexts)
            ), "Each list in 'titles' must have the same length as the corresponding list in 'context'"
        assert len(queries) == len(
            contexts
        ), "Lists 'queries' and 'contexts' must have the same length"

        dataset = TestDataset(
            queries=queries,
            contexts=contexts,
            titles=titles,
            tokenizer=self.tokenizer,
            max_len=self.max_len,
            enable_warnings=enable_warnings,
        )
        selected_contexts = [
            [{0: contexts[i][j]} for j in range(len(contexts[i]))]
            for i in range(len(queries))
        ]
        reranking_scores = [
            [None for j in range(len(contexts[i]))] for i in range(len(queries))
        ]
        compressions = [
            [0 for j in range(len(contexts[i]))] for i in range(len(queries))
        ]

        with torch.no_grad():
            for batch_start in tqdm(
                range(0, len(dataset), batch_size), desc="Pruning contexts..."
            ):
                qis = dataset.qis[batch_start : batch_start + batch_size]
                cis = dataset.cis[batch_start : batch_start + batch_size]
                sis = dataset.sis[batch_start : batch_start + batch_size]
                sent_coords = dataset.sent_coords[
                    batch_start : batch_start + batch_size
                ]
                ids_list = dataset.ids[batch_start : batch_start + batch_size]
                ids = pad_sequence(
                    ids_list, batch_first=True, padding_value=dataset.pad_idx
                ).to(self.device)
                mask = (ids != dataset.pad_idx).to(self.device)
                outputs = self.forward(ids, mask)
                scores = F.softmax(outputs["compression_logits"].cpu(), dim=-1)[:, :, 1]
                token_preds = scores > threshold
                reranking_scrs = (
                    outputs["ranking_scores"].cpu().numpy()
                )  # one relevance score per sequence in the batch
                if len(reranking_scrs.shape) == 0:
                    reranking_scrs = reranking_scrs[None]
                for (
                    ids_list_,
                    token_preds_,
                    rerank_score,
                    qi,
                    ci,
                    si,
                    sent_coords_,
                ) in zip(
                    ids_list, token_preds, reranking_scrs, qis, cis, sis, sent_coords
                ):
                    selected_mask = sentence_rounding(
                        token_preds_.numpy(),
                        np.array(sent_coords_),
                        threshold=threshold,
                        always_select_title=always_select_title
                        and si == 0
                        and titles is not None,
                    )
                    assert len(selected_mask) == len(token_preds_)
                    selected_contexts[qi][ci][si] = ids_list_[
                        selected_mask[: len(ids_list_)]
                    ]
                    if si == 0:
                        reranking_scores[qi][ci] = rerank_score

        for i in range(len(queries)):
            for j in range(len(contexts[i])):
                if not isinstance(selected_contexts[i][j][0], str):
                    # concatenate the selected token ids of all blocks, in block order
                    toks = torch.cat(
                        [
                            ids_
                            for _, ids_ in sorted(
                                selected_contexts[i][j].items(), key=lambda x: x[0]
                            )
                        ]
                    )
                    selected_contexts[i][j] = self.tokenizer.decode(
                        toks,
                        skip_special_tokens=True,
                        clean_up_tokenization_spaces=False,
                    )
                else:
                    selected_contexts[i][j] = selected_contexts[i][j][0]
                len_original = len(contexts[i][j])
                len_compressed = len(selected_contexts[i][j])
                compressions[i][j] = (len_original - len_compressed) / len_original * 100
            if reorder:
                idxs = np.argsort(reranking_scores[i])[::-1][:top_k]
                selected_contexts[i] = [selected_contexts[i][j] for j in idxs]
                reranking_scores[i] = [reranking_scores[i][j] for j in idxs]
                compressions[i] = [compressions[i][j] for j in idxs]

        if isinstance(context, str):
            selected_contexts = selected_contexts[0][0]
            reranking_scores = reranking_scores[0][0]
            compressions = compressions[0][0]

        return {
            "pruned_context": selected_contexts,
            "reranking_score": reranking_scores,
            "compression_rate": compressions,
        }
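# Example call for `process` with batched inputs (a sketch; the strings are
# illustrative, not taken from this file):
#
#     out = model.process(
#         question=["q1", "q2"],
#         context=[["passage 1 for q1", "passage 2 for q1"], ["passage for q2"]],
#         reorder=True,   # sort each query's passages by reranking score
#         top_k=5,        # and keep at most 5 of them
#     )
#     # out["pruned_context"][i][j]   -> pruned passage j for query i (str)
#     # out["reranking_score"][i][j]  -> its relevance score (float)
#     # out["compression_rate"][i][j] -> % of characters removed
#
# With a single `question` str and a single `context` str, the three fields are
# returned unnested (a str and two scalars).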
# Some utils functions

def sentence_rounding(predictions, chunks, threshold, always_select_title=True):
    """
    predictions: a binary vector containing 1 for tokens which were selected and 0 otherwise
    chunks: a list of pairs [start, end] per sentence, i.e. a sentence lives at
        coordinates predictions[start:end]
    The function rounds token-level predictions to the sentence level: a sentence
    is selected iff the mean of its token predictions exceeds the threshold.
    """
    cumulative_sum = np.cumsum(predictions)
    chunk_sums = cumulative_sum[chunks[:, 1] - 1] - np.where(
        chunks[:, 0] > 0, cumulative_sum[chunks[:, 0] - 1], 0
    )
    chunk_lengths = chunks[:, 1] - chunks[:, 0]
    chunk_means = chunk_sums / chunk_lengths
    if always_select_title and (chunk_means > threshold).any():
        chunk_means[0] = 1
    means = np.hstack((np.zeros(1), chunk_means, np.zeros(1)))
    repeats = np.hstack(
        ([chunks[0][0]], chunk_lengths, [predictions.shape[0] - chunks[-1][1]])
    )
    return np.repeat(means, repeats) > threshold
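# A small worked example for `sentence_rounding` (illustrative only):
#
#     preds  = np.array([1, 1, 0, 0, 0, 1])  # token-level selections
#     chunks = np.array([[0, 2], [2, 5]])    # two sentences; token 5 is trailing
#     sentence_rounding(preds, chunks, threshold=0.3)
#     # sentence 0 mean = 1.0 > 0.3 -> kept; sentence 1 mean = 0.0 -> dropped
#     # -> array([ True,  True, False, False, False, False])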
def normalize(s: str) -> str:
    def white_space_fix(text):
        return " ".join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return "".join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_punc(lower(s)))


def sent_split_and_tokenize(text, tokenizer, max_len):
    # split into sentences with the multilingual spacy pipeline
    raw_sents = [sent.text.strip() for sent in nlp(text).sents]
    sents = []
    for j, raw_sent in enumerate(raw_sents):
        # prepend a space to all but the first sentence so that decoding is lossless
        tokinput = (" " if j != 0 else "") + raw_sent
        tok = tokenizer.encode(tokinput, add_special_tokens=False)
        ltok = len(tok)
        if ltok == 0:
            continue
        if ltok <= max_len:
            sents.append(tok)
        else:
            # a sentence longer than max_len tokens is split into max_len-sized chunks
            for begin in range(0, ltok, max_len):
                sents.append(tok[begin : begin + max_len])
    return sents
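# Illustrative behaviour of `sent_split_and_tokenize` (a sketch; the tokenizer
# name is an assumption and exact token ids depend on the vocabulary):
#
#     tok = AutoTokenizer.from_pretrained("xlm-roberta-base")
#     sent_split_and_tokenize("First sentence. Second one.", tok, max_len=510)
#     # -> [ids("First sentence."), ids(" Second one.")]
#     # i.e. one id list per spacy sentence; every sentence after the first is
#     # encoded with a leading space, and any sentence longer than max_len
#     # tokens is chopped into max_len-sized pieces.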
class TestDataset(Dataset):
    def __init__(
        self,
        queries,
        contexts,
        tokenizer,
        max_len=6000,
        titles="first_sentence",
        enable_warnings=True,
    ):
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.pad_idx = self.tokenizer.pad_token_id
        self.cls_idx = [self.tokenizer.cls_token_id]
        self.sep_idx = [self.tokenizer.sep_token_id]
        self.eos = [self.tokenizer.eos_token_id]  # special-token ids taken from the model's tokenizer
        self.nb_spe_tok = len(self.cls_idx) + len(self.sep_idx)
        self.enable_warnings = enable_warnings
        self.unusual_query_length = (
            self.max_len // 2
        )  # TODO: change to data-driven value
        self.unusual_title_length = (
            self.max_len // 2
        )  # TODO: change to data-driven value
        self.create_dataset(queries, contexts, titles)
        self.len = len(self.cis)

    def create_dataset(self, queries, contexts, titles="first_sentence"):
        self.qis = []
        self.cis = []
        self.sis = []
        self.sent_coords = []
        self.cntx_coords = []
        self.ids = []
        if self.enable_warnings:
            warnings_dict = {
                "zero_len_query": set(),
                "too_long_query": set(),
                "unusually_long_query": set(),
                "unusually_long_title": set(),
                "split_context": set(),
            }
        for i, query in enumerate(queries):
            # normalize the query because all training data has normalized queries
            tokenized_query = self.tokenizer.encode(
                normalize(query), add_special_tokens=False
            )
            query_len = len(tokenized_query)
            if query_len == 0:
                if self.enable_warnings:
                    warnings_dict["zero_len_query"].add(i)
                continue
            elif query_len >= self.max_len - self.nb_spe_tok - 1:  # -1 for eos
                if self.enable_warnings:
                    warnings_dict["too_long_query"].add(i)
                continue
            elif query_len >= self.unusual_query_length:
                if self.enable_warnings:
                    warnings_dict["unusually_long_query"].add(i)
            left_0 = len(tokenized_query) + self.nb_spe_tok
            tokenized_seq_0 = self.cls_idx + tokenized_query + self.sep_idx
            max_len = self.max_len - left_0 - 1
            for j, cntx in enumerate(contexts[i]):
                title = titles[i][j] if isinstance(titles, list) else titles
                # each (sent + query + special tokens) <= max_len
                tokenized_sents = sent_split_and_tokenize(cntx, self.tokenizer, max_len)
                if title is not None and title != "first_sentence":
                    tokenized_title = self.tokenizer.encode(
                        title, add_special_tokens=False
                    )
                    ltok = len(tokenized_title)
                    if ltok == 0:
                        pass
                    elif ltok <= max_len:
                        tokenized_sents = [tokenized_title] + tokenized_sents
                    else:
                        if self.enable_warnings and ltok >= self.unusual_title_length:
                            warnings_dict["unusually_long_title"].add(i)
                        tokenized_sents = [
                            tokenized_title[begin : begin + max_len]
                            for begin in range(0, ltok, max_len)
                        ] + tokenized_sents
                tokenized_seq = tokenized_seq_0
                left = left_0
                sent_coords = []
                block = 0
                for idx, tokenized_sent in enumerate(tokenized_sents):
                    l = len(tokenized_sent)
                    if left + l <= self.max_len - 1:
                        sent_coords.append([left, left + l])
                        tokenized_seq = tokenized_seq + tokenized_sent
                        left += l
                    else:
                        # the sentence does not fit: flush the current block and start a new one
                        if self.enable_warnings:
                            warnings_dict["split_context"].add(i)
                        if len(tokenized_seq) > left_0:
                            tokenized_seq = tokenized_seq + self.eos
                            self.qis.append(i)
                            self.cis.append(j)
                            self.sis.append(block)
                            self.sent_coords.append(sent_coords)
                            self.cntx_coords.append(
                                [sent_coords[0][0], sent_coords[-1][1]]
                            )
                            self.ids.append(torch.tensor(tokenized_seq))
                        tokenized_seq = tokenized_seq_0 + tokenized_sent
                        sent_coords = [[left_0, left_0 + l]]
                        left = left_0 + l
                        block += 1
                if len(tokenized_seq) > left_0:
                    tokenized_seq = tokenized_seq + self.eos
                    self.qis.append(i)
                    self.cis.append(j)
                    self.sis.append(block)
                    self.sent_coords.append(sent_coords)
                    self.cntx_coords.append([sent_coords[0][0], sent_coords[-1][1]])
                    self.ids.append(torch.tensor(tokenized_seq))
        if self.enable_warnings:
            self.print_warnings(warnings_dict, len(queries))

    def __len__(self):
        return len(self.ids)

    def print_warnings(self, warnings_dict, N):
        info = " You can suppress XProvence warnings by setting enable_warnings=False."
        n = len(warnings_dict["zero_len_query"])
        if n > 0:
            ex = list(warnings_dict["zero_len_query"])[:10]
            warnings.warn(
                f"{n} out of {N} queries have zero length, e.g. at indexes {ex}. "
                "These examples will be skipped in context pruning; "
                "their contexts will be kept as is." + info
            )
        n = len(warnings_dict["too_long_query"])
        if n > 0:
            ex = list(warnings_dict["too_long_query"])[:10]
            warnings.warn(
                f"{n} out of {N} queries are too long for context length {self.max_len}, "
                f"e.g. at indexes {ex}. These examples will be skipped in context pruning; "
                "their contexts will be kept as is." + info
            )
        n = len(warnings_dict["unusually_long_query"])
        if n > 0:
            ex = list(warnings_dict["unusually_long_query"])[:10]
            warnings.warn(
                f"{n} out of {N} queries are longer than {self.unusual_query_length} tokens, "
                f"e.g. at indexes {ex}. These examples will be processed as usual in context pruning, "
                "but the quality of context pruning could be reduced." + info
            )
        n = len(warnings_dict["unusually_long_title"])
        if n > 0:
            ex = list(warnings_dict["unusually_long_title"])[:10]
            warnings.warn(
                f"{n} out of {N} titles are longer than {self.unusual_title_length} tokens, "
                f"e.g. at indexes {ex}. These examples will be processed as usual in context pruning, "
                "but the quality of context pruning could be reduced." + info
            )
        n = len(warnings_dict["split_context"])
        if n > 0:
            ex = list(warnings_dict["split_context"])[:10]
            warnings.warn(
                f"{n} out of {N} contexts were split into several pieces for context pruning, "
                f"due to the limited context length of XProvence, which is equal to {self.max_len}. "
                "This could potentially reduce the quality of context pruning. "
                "Consider checking and reducing the lengths of contexts, queries, or titles." + info
            )
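if __name__ == "__main__":
    # Minimal smoke test (a sketch: the checkpoint path below is a placeholder
    # and must point to a trained XProvence checkpoint).
    model = XProvence.from_pretrained("path/to/xprovence-checkpoint")
    model.eval()
    result = model.process(
        question="When was the Eiffel Tower built?",
        context=(
            "The Eiffel Tower was constructed from 1887 to 1889. "
            "It is named after the engineer Gustave Eiffel. "
            "Paris is also known for its museums and cafes."
        ),
    )
    print(result["pruned_context"])    # context with irrelevant sentences removed
    print(result["reranking_score"])   # query-passage relevance score
    print(result["compression_rate"])  # % of characters pruned away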