|
import os |
|
import torch |
|
import collections |
|
import logging |
|
from tqdm import tqdm, trange |
|
import json |
|
import bs4 |
|
from os import path as osp |
|
from bs4 import BeautifulSoup as bs |
|
|
|
from torch.utils.data import Dataset |
|
import networkx as nx |
|
from lxml import etree |
|
import pickle |
|
|
|
from transformers import BertTokenizer |
|
import argparse |
|
|
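# Mapping from HTML/SVG tag names to integer ids used for the xpath tag-name embeddings;
# ids beyond this dict are reserved for the "unknown" and "padding" tags.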
|
tags_dict = {'a': 0, 'abbr': 1, 'acronym': 2, 'address': 3, 'altGlyph': 4, 'altGlyphDef': 5, 'altGlyphItem': 6, |
|
'animate': 7, 'animateColor': 8, 'animateMotion': 9, 'animateTransform': 10, 'applet': 11, 'area': 12, |
|
'article': 13, 'aside': 14, 'audio': 15, 'b': 16, 'base': 17, 'basefont': 18, 'bdi': 19, 'bdo': 20, |
|
'bgsound': 21, 'big': 22, 'blink': 23, 'blockquote': 24, 'body': 25, 'br': 26, 'button': 27, 'canvas': 28, |
|
'caption': 29, 'center': 30, 'circle': 31, 'cite': 32, 'clipPath': 33, 'code': 34, 'col': 35, |
|
'colgroup': 36, 'color-profile': 37, 'content': 38, 'cursor': 39, 'data': 40, 'datalist': 41, 'dd': 42, |
|
'defs': 43, 'del': 44, 'desc': 45, 'details': 46, 'dfn': 47, 'dialog': 48, 'dir': 49, 'div': 50, 'dl': 51, |
|
'dt': 52, 'ellipse': 53, 'em': 54, 'embed': 55, 'feBlend': 56, 'feColorMatrix': 57, |
|
'feComponentTransfer': 58, 'feComposite': 59, 'feConvolveMatrix': 60, 'feDiffuseLighting': 61, |
|
'feDisplacementMap': 62, 'feDistantLight': 63, 'feFlood': 64, 'feFuncA': 65, 'feFuncB': 66, 'feFuncG': 67, |
|
'feFuncR': 68, 'feGaussianBlur': 69, 'feImage': 70, 'feMerge': 71, 'feMergeNode': 72, 'feMorphology': 73, |
|
'feOffset': 74, 'fePointLight': 75, 'feSpecularLighting': 76, 'feSpotLight': 77, 'feTile': 78, |
|
'feTurbulence': 79, 'fieldset': 80, 'figcaption': 81, 'figure': 82, 'filter': 83, 'font-face-format': 84, |
|
'font-face-name': 85, 'font-face-src': 86, 'font-face-uri': 87, 'font-face': 88, 'font': 89, 'footer': 90, |
|
'foreignObject': 91, 'form': 92, 'frame': 93, 'frameset': 94, 'g': 95, 'glyph': 96, 'glyphRef': 97, |
|
'h1': 98, 'h2': 99, 'h3': 100, 'h4': 101, 'h5': 102, 'h6': 103, 'head': 104, 'header': 105, 'hgroup': 106, |
|
'hkern': 107, 'hr': 108, 'html': 109, 'i': 110, 'iframe': 111, 'image': 112, 'img': 113, 'input': 114, |
|
'ins': 115, 'kbd': 116, 'keygen': 117, 'label': 118, 'legend': 119, 'li': 120, 'line': 121, |
|
'linearGradient': 122, 'link': 123, 'main': 124, 'map': 125, 'mark': 126, 'marker': 127, 'marquee': 128, |
|
'mask': 129, 'math': 130, 'menu': 131, 'menuitem': 132, 'meta': 133, 'metadata': 134, 'meter': 135, |
|
'missing-glyph': 136, 'mpath': 137, 'nav': 138, 'nobr': 139, 'noembed': 140, 'noframes': 141, |
|
'noscript': 142, 'object': 143, 'ol': 144, 'optgroup': 145, 'option': 146, 'output': 147, 'p': 148, |
|
'param': 149, 'path': 150, 'pattern': 151, 'picture': 152, 'plaintext': 153, 'polygon': 154, |
|
'polyline': 155, 'portal': 156, 'pre': 157, 'progress': 158, 'q': 159, 'radialGradient': 160, 'rb': 161, |
|
'rect': 162, 'rp': 163, 'rt': 164, 'rtc': 165, 'ruby': 166, 's': 167, 'samp': 168, 'script': 169, |
|
'section': 170, 'select': 171, 'set': 172, 'shadow': 173, 'slot': 174, 'small': 175, 'source': 176, |
|
'spacer': 177, 'span': 178, 'stop': 179, 'strike': 180, 'strong': 181, 'style': 182, 'sub': 183, |
|
'summary': 184, 'sup': 185, 'svg': 186, 'switch': 187, 'symbol': 188, 'table': 189, 'tbody': 190, |
|
'td': 191, 'template': 192, 'text': 193, 'textPath': 194, 'textarea': 195, 'tfoot': 196, 'th': 197, |
|
'thead': 198, 'time': 199, 'title': 200, 'tr': 201, 'track': 202, 'tref': 203, 'tspan': 204, 'tt': 205, |
|
'u': 206, 'ul': 207, 'use': 208, 'var': 209, 'video': 210, 'view': 211, 'vkern': 212, 'wbr': 213, |
|
'xmp': 214} |
|
|
|
|
|
def whitespace_tokenize(text): |
|
"""Runs basic whitespace cleaning and splitting on a piece of text.""" |
|
text = text.strip() |
|
if not text: |
|
return [] |
|
tokens = text.split() |
|
return tokens |
|
|
|
|
|
|
|
def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer, orig_answer_text): |
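    """Return the (start, end) sub-token span that best matches the tokenized answer text,
    ignoring tag tokens of the form <...>; falls back to the original span if no tighter
    match is found."""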
|
tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text)) |
|
|
|
for new_start in range(input_start, input_end + 1): |
|
for new_end in range(input_end, new_start - 1, -1): |
|
text_span = " ".join([w for w in doc_tokens[new_start:(new_end + 1)] |
|
if w[0] != '<' or w[-1] != '>']) |
|
if text_span == tok_answer_text: |
|
return new_start, new_end |
|
|
|
return input_start, input_end |
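

# Illustrative sketch only (not part of the original pipeline): a toy tokenizer stand-in
# showing how _improve_answer_span narrows a span to the tokenized answer.
def _example_improve_answer_span():
    class _WhitespaceTok:  # hypothetical stand-in for a real subword tokenizer
        def tokenize(self, text):
            return text.split()

    doc = ["the", "japanese", "##ness", "of", "it"]
    # The initial guess (1, 2) also covers "##ness"; the improved span is (1, 1).
    return _improve_answer_span(doc, 1, 2, _WhitespaceTok(), "japanese")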
|
|
|
class StrucDataset(Dataset): |
|
"""Dataset wrapping tensors. |
|
|
|
Each sample will be retrieved by indexing tensors along the first dimension. |
|
|
|
Arguments: |
|
*tensors (*torch.Tensor): tensors that have the same size of the first dimension. |
|
page_ids (list): the corresponding page ids of the input features. |
|
cnn_feature_dir (str): the direction where the cnn features are stored. |
|
token_to_tag (torch.Tensor): the mapping from each token to its corresponding tag id. |
|
""" |
|
|
|
def __init__(self, *tensors, pad_id=0, |
|
all_expended_attention_mask=None, |
|
all_graph_names=None, |
|
all_token_to_tag=None, |
|
page_ids=None, |
|
attention_width=None, |
|
                 has_tree_attention_bias=False):
|
tensors = tuple(tensor for tensor in tensors) |
|
assert all(len(tensors[0]) == len(tensor) for tensor in tensors) |
|
if all_expended_attention_mask is not None: |
|
assert len(tensors[0]) == len(all_expended_attention_mask) |
|
tensors += (all_expended_attention_mask,) |
|
self.tensors = tensors |
|
self.page_ids = page_ids |
|
self.all_graph_names = all_graph_names |
|
self.all_token_to_tag = all_token_to_tag |
|
self.pad_id = pad_id |
|
self.attention_width = attention_width |
|
self.has_tree_attention_bias = has_tree_attention_bias |
|
|
|
def __getitem__(self, index): |
|
output = [tensor[index] for tensor in self.tensors] |
|
|
|
input_id = output[0] |
|
attention_mask = output[1] |
|
|
|
|
|
        if self.attention_width is not None or self.has_tree_attention_bias:
            assert self.all_graph_names is not None, ("For non-empty attention_width / tree rel pos, "
                                                      "graph names must be provided!")
|
|
|
if self.all_graph_names is not None: |
|
assert self.all_token_to_tag is not None |
|
graph_name = self.all_graph_names[index] |
|
token_to_tag = self.all_token_to_tag[index] |
|
with open(graph_name,"rb") as f: |
|
node_pairs_lengths = pickle.load(f) |
|
|
|
|
|
|
|
seq_len = len(token_to_tag) |
|
if self.has_tree_attention_bias: |
|
                # Build an independent row per position; a repeated inner list would alias all rows.
                mat = [[0] * seq_len for _ in range(seq_len)]
|
else: |
|
mat = None |
|
|
|
if self.attention_width is not None: |
|
                # clone() so that the later in-place edits do not write through the expanded view
                emask = attention_mask.expand(seq_len, seq_len).clone()
|
else: |
|
emask = None |
|
|
|
for nid in range(seq_len): |
|
if input_id[nid]==self.pad_id: |
|
break |
|
for anid in range(nid+1,seq_len): |
|
if input_id[anid]==self.pad_id: |
|
break |
|
|
|
x_tid4nid = token_to_tag[nid] |
|
x_tid4anid = token_to_tag[anid] |
|
|
|
if x_tid4nid==x_tid4anid: |
|
continue |
|
|
|
                    try:
                        xx = node_pairs_lengths[x_tid4nid]
                    except (KeyError, IndexError):
                        # Unknown tag id: fall back to the catch-all entry.
                        xx = node_pairs_lengths[-1]
                        x_tid4nid = -1

                    try:
                        dis = xx[x_tid4anid]
                    except (KeyError, IndexError):
                        dis = xx[-1]
                        x_tid4anid = -1
|
|
|
|
|
|
|
|
|
if self.has_tree_attention_bias: |
|
if x_tid4nid<x_tid4anid: |
|
mat[nid][anid]=dis |
|
mat[anid][nid]=-dis |
|
else: |
|
mat[nid][anid] = -dis |
|
mat[anid][nid] = dis |
|
|
|
if self.attention_width is not None: |
|
|
|
if x_tid4nid==-1 or x_tid4anid==-1: |
|
continue |
|
|
|
if dis>self.attention_width: |
|
emask[nid][anid]=0 |
|
emask[anid][nid]=0 |
|
|
|
|
|
if self.attention_width is not None: |
|
output.append(emask) |
|
|
|
if self.has_tree_attention_bias: |
|
t_mat = torch.tensor(mat,dtype=torch.long) |
|
output.append(t_mat) |
|
|
|
|
|
return tuple(item for item in output) |
|
|
|
def __len__(self): |
|
return len(self.tensors[0]) |
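

# Minimal usage sketch (illustrative only): wrap pre-computed feature tensors and iterate over
# them with a DataLoader. The tensors below are random placeholders, not the original pipeline's
# features.
def _example_strucdataset_usage():
    from torch.utils.data import DataLoader

    input_ids = torch.randint(0, 100, (8, 16))            # (num_features, seq_len)
    attention_mask = torch.ones(8, 16, dtype=torch.long)
    start_positions = torch.zeros(8, dtype=torch.long)

    dataset = StrucDataset(input_ids, attention_mask, start_positions)
    loader = DataLoader(dataset, batch_size=4)
    return next(iter(loader))  # a tuple of three batched tensors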
|
|
|
|
|
def get_xpath4tokens(html_fn: str, unique_tids: set): |
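    """Return an xpath string for every tid in unique_tids by parsing the tid-annotated HTML
    file with lxml; the two extra ids (used for the special 'yes'/'no' tokens) map to "/html"."""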
|
xpath_map = {} |
|
tree = etree.parse(html_fn, etree.HTMLParser()) |
|
nodes = tree.xpath('//*') |
|
for node in nodes: |
|
        tid = node.attrib.get("tid")
        if tid is not None and int(tid) in unique_tids:
            xpath_map[int(tid)] = tree.getpath(node)
|
xpath_map[len(nodes)] = "/html" |
|
xpath_map[len(nodes) + 1] = "/html" |
|
return xpath_map |
|
|
|
|
|
def get_xpath_and_treeid4tokens(html_code, unique_tids, max_depth): |
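    """For every tid, derive from its DOM position: the tag-name ids along its xpath and the
    sibling subscripts along that xpath (both truncated/padded to max_depth), plus an unpadded
    tree-id path that is later used for tree-distance computations."""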
|
unknown_tag_id = len(tags_dict) |
|
pad_tag_id = unknown_tag_id + 1 |
|
max_width = 1000 |
|
width_pad_id = 1001 |
|
|
|
pad_x_tag_seq = [pad_tag_id] * max_depth |
|
pad_x_subs_seq = [width_pad_id] * max_depth |
|
pad_x_box = [0,0,0,0] |
|
pad_tree_id_seq = [width_pad_id] * max_depth |
|
|
|
def xpath_soup(element): |
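        """Walk from the element up to the root and return, in root-to-leaf order, the tag
        names, the per-tag-name sibling subscripts, and the all-children indices (tree id)."""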
|
|
|
xpath_tags = [] |
|
xpath_subscripts = [] |
|
tree_index = [] |
|
child = element if element.name else element.parent |
|
for parent in child.parents: |
|
siblings = parent.find_all(child.name, recursive=False) |
|
para_siblings = parent.find_all(True, recursive=False) |
|
xpath_tags.append(child.name) |
|
xpath_subscripts.append( |
|
0 if 1 == len(siblings) else next(i for i, s in enumerate(siblings, 1) if s is child)) |
|
|
|
tree_index.append(next(i for i, s in enumerate(para_siblings, 0) if s is child)) |
|
child = parent |
|
xpath_tags.reverse() |
|
xpath_subscripts.reverse() |
|
tree_index.reverse() |
|
return xpath_tags, xpath_subscripts, tree_index |
|
|
|
xpath_tag_map = {} |
|
xpath_subs_map = {} |
|
tree_id_map = {} |
|
|
|
for tid in unique_tids: |
|
element = html_code.find(attrs={'tid': tid}) |
|
if element is None: |
|
xpath_tags = pad_x_tag_seq |
|
xpath_subscripts = pad_x_subs_seq |
|
tree_index = pad_tree_id_seq |
|
|
|
xpath_tag_map[tid] = xpath_tags |
|
xpath_subs_map[tid] = xpath_subscripts |
|
tree_id_map[tid] = tree_index |
|
continue |
|
|
|
xpath_tags, xpath_subscripts, tree_index = xpath_soup(element) |
|
|
|
assert len(xpath_tags) == len(xpath_subscripts) |
|
assert len(xpath_tags) == len(tree_index) |
|
|
|
if len(xpath_tags) > max_depth: |
|
xpath_tags = xpath_tags[-max_depth:] |
|
xpath_subscripts = xpath_subscripts[-max_depth:] |
|
|
|
|
|
xpath_tags = [tags_dict.get(name, unknown_tag_id) for name in xpath_tags] |
|
xpath_subscripts = [min(i, max_width) for i in xpath_subscripts] |
|
tree_index = [min(i, max_width) for i in tree_index] |
|
|
|
|
|
|
|
xpath_tags += [pad_tag_id] * (max_depth - len(xpath_tags)) |
|
xpath_subscripts += [width_pad_id] * (max_depth - len(xpath_subscripts)) |
|
|
|
|
|
xpath_tag_map[tid] = xpath_tags |
|
xpath_subs_map[tid] = xpath_subscripts |
|
tree_id_map[tid] = tree_index |
|
|
|
return xpath_tag_map, xpath_subs_map, tree_id_map |
|
|
|
|
|
|
|
def _check_is_max_context(doc_spans, cur_span_index, position): |
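    """Standard SQuAD-style check: when a token appears in several overlapping doc spans, it is
    treated as belonging to the span in which it has the most surrounding context
    (min(left, right) context, with a small bonus for longer spans)."""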
|
best_score = None |
|
best_span_index = None |
|
for (span_index, doc_span) in enumerate(doc_spans): |
|
end = doc_span.start + doc_span.length - 1 |
|
if position < doc_span.start: |
|
continue |
|
if position > end: |
|
continue |
|
num_left_context = position - doc_span.start |
|
num_right_context = end - position |
|
score = min(num_left_context, num_right_context) + 0.01 * doc_span.length |
|
if best_score is None or score > best_score: |
|
best_score = score |
|
best_span_index = span_index |
|
|
|
return cur_span_index == best_span_index |
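

# Illustrative sketch only: position 5 falls inside both spans below, but has more surrounding
# context in the second one, so only the second span counts it as "max context".
def _example_check_is_max_context():
    doc_span = collections.namedtuple("DocSpan", ["start", "length"])
    spans = [doc_span(start=0, length=6), doc_span(start=3, length=6)]
    return _check_is_max_context(spans, 0, 5), _check_is_max_context(spans, 1, 5)  # (False, True)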
|
|
|
|
|
class SRCExample(object): |
|
r""" |
|
    The container for SRC Examples.
|
|
|
Arguments: |
|
doc_tokens (list[str]): the original tokens of the HTML file before dividing into sub-tokens. |
|
qas_id (str): the id of the corresponding question. |
|
tag_num (int): the total tag number in the corresponding HTML file, including the additional 'yes' and 'no'. |
|
question_text (str): the text of the corresponding question. |
|
orig_answer_text (str): the answer text provided by the dataset. |
|
all_doc_tokens (list[str]): the sub-tokens of the corresponding HTML file. |
|
start_position (int): the position where the answer starts in the all_doc_tokens. |
|
end_position (int): the position where the answer ends in the all_doc_tokens; NOTE that the answer tokens |
|
include the token at end_position. |
|
tok_to_orig_index (list[int]): the mapping from sub-tokens (all_doc_tokens) to origin tokens (doc_tokens). |
|
orig_to_tok_index (list[int]): the mapping from origin tokens (doc_tokens) to sub-tokens (all_doc_tokens). |
|
tok_to_tags_index (list[int]): the mapping from sub-tokens (all_doc_tokens) to the id of the deepest tag it |
|
belongs to. |
|
""" |
|
|
|
|
|
|
|
def __init__(self, |
|
doc_tokens, |
|
qas_id, |
|
tag_num, |
|
question_text=None, |
|
html_code=None, |
|
orig_answer_text=None, |
|
start_position=None, |
|
end_position=None, |
|
tok_to_orig_index=None, |
|
orig_to_tok_index=None, |
|
all_doc_tokens=None, |
|
tok_to_tags_index=None, |
|
xpath_tag_map=None, |
|
xpath_subs_map=None, |
|
xpath_box=None, |
|
tree_id_map=None, |
|
visible_matrix=None, |
|
): |
|
self.doc_tokens = doc_tokens |
|
self.qas_id = qas_id |
|
self.tag_num = tag_num |
|
self.question_text = question_text |
|
self.html_code = html_code |
|
self.orig_answer_text = orig_answer_text |
|
self.start_position = start_position |
|
self.end_position = end_position |
|
self.tok_to_orig_index = tok_to_orig_index |
|
self.orig_to_tok_index = orig_to_tok_index |
|
self.all_doc_tokens = all_doc_tokens |
|
self.tok_to_tags_index = tok_to_tags_index |
|
self.xpath_tag_map = xpath_tag_map |
|
self.xpath_subs_map = xpath_subs_map |
|
self.xpath_box = xpath_box |
|
self.tree_id_map = tree_id_map |
|
self.visible_matrix = visible_matrix |
|
|
|
def __str__(self): |
|
return self.__repr__() |
|
|
|
def __repr__(self): |
|
""" |
|
s = "" |
|
s += "qas_id: %s" % self.qas_id |
|
s += ", question_text: %s" % ( |
|
self.question_text) |
|
s += ", doc_tokens: [%s]" % (" ".join(self.doc_tokens)) |
|
if self.start_position: |
|
s += ", start_position: %d" % self.start_position |
|
if self.end_position: |
|
s += ", end_position: %d" % self.end_position |
|
""" |
|
s = "[INFO]\n" |
|
s += f"qas_id ({type(self.qas_id)}): {self.qas_id}\n" |
|
s += f"tag_num ({type(self.tag_num)}): {self.tag_num}\n" |
|
s += f"question_text ({type(self.question_text)}): {self.question_text}\n" |
|
s += f"html_code ({type(self.html_code)}): {self.html_code}\n" |
|
s += f"orig_answer_text ({type(self.orig_answer_text)}): {self.orig_answer_text}\n" |
|
s += f"start_position ({type(self.start_position)}): {self.start_position}\n" |
|
s += f"end_position ({type(self.end_position)}): {self.end_position}\n" |
|
s += f"tok_to_orig_index ({type(self.tok_to_orig_index)}): {self.tok_to_orig_index}\n" |
|
s += f"orig_to_tok_index ({type(self.orig_to_tok_index)}): {self.orig_to_tok_index}\n" |
|
s += f"all_doc_tokens ({type(self.all_doc_tokens)}): {self.all_doc_tokens}\n" |
|
s += f"tok_to_tags_index ({type(self.tok_to_tags_index)}): {self.tok_to_tags_index}\n" |
|
s += f"xpath_tag_map ({type(self.xpath_tag_map)}): {self.xpath_tag_map}\n" |
|
s += f"xpath_subs_map ({type(self.xpath_subs_map)}): {self.xpath_subs_map}\n" |
|
s += f"tree_id_map ({type(self.tree_id_map)}): {self.tree_id_map}\n" |
|
|
|
return s |
|
|
|
|
|
|
|
|
|
class InputFeatures(object): |
|
r""" |
|
The Container for the Features of Input Doc Spans. |
|
|
|
Arguments: |
|
unique_id (int): the unique id of the input doc span. |
|
example_index (int): the index of the corresponding SRC Example of the input doc span. |
|
page_id (str): the id of the corresponding web page of the question. |
|
        doc_span_index (int): the index of the doc span among all the doc spans which correspond to the same SRC
            Example.
|
tokens (list[str]): the sub-tokens of the input sequence, including cls token, sep tokens, and the sub-tokens |
|
of the question and HTML file. |
|
token_to_orig_map (dict[int, int]): the mapping from the HTML file's sub-tokens in the sequence tokens (tokens) |
|
to the origin tokens (all_tokens in the corresponding SRC Example). |
|
token_is_max_context (dict[int, bool]): whether the current doc span contains the max pre- and post-context for |
|
each HTML file's sub-tokens. |
|
input_ids (list[int]): the ids of the sub-tokens in the input sequence (tokens). |
|
input_mask (list[int]): use 0/1 to distinguish the input sequence from paddings. |
|
segment_ids (list[int]): use 0/1 to distinguish the question and the HTML files. |
|
paragraph_len (int): the length of the HTML file's sub-tokens. |
|
start_position (int): the position where the answer starts in the input sequence (0 if the answer is not fully |
|
in the input sequence). |
|
end_position (int): the position where the answer ends in the input sequence; NOTE that the answer tokens |
|
include the token at end_position (0 if the answer is not fully in the input sequence). |
|
token_to_tag_index (list[int]): the mapping from sub-tokens of the input sequence to the id of the deepest tag |
|
it belongs to. |
|
        is_impossible (bool): whether the answer is NOT fully contained in the doc span.
|
""" |
|
|
|
def __init__(self, |
|
unique_id, |
|
example_index, |
|
page_id, |
|
doc_span_index, |
|
tokens, |
|
token_to_orig_map, |
|
token_is_max_context, |
|
input_ids, |
|
input_mask, |
|
segment_ids, |
|
paragraph_len, |
|
start_position=None, |
|
end_position=None, |
|
token_to_tag_index=None, |
|
is_impossible=None, |
|
xpath_tags_seq=None, |
|
xpath_subs_seq=None, |
|
xpath_box_seq=None, |
|
extended_attention_mask=None): |
|
self.unique_id = unique_id |
|
self.example_index = example_index |
|
self.page_id = page_id |
|
self.doc_span_index = doc_span_index |
|
self.tokens = tokens |
|
self.token_to_orig_map = token_to_orig_map |
|
self.token_is_max_context = token_is_max_context |
|
self.input_ids = input_ids |
|
self.input_mask = input_mask |
|
self.segment_ids = segment_ids |
|
self.paragraph_len = paragraph_len |
|
self.start_position = start_position |
|
self.end_position = end_position |
|
self.token_to_tag_index = token_to_tag_index |
|
self.is_impossible = is_impossible |
|
self.xpath_tags_seq = xpath_tags_seq |
|
self.xpath_subs_seq = xpath_subs_seq |
|
self.xpath_box_seq = xpath_box_seq |
|
self.extended_attention_mask = extended_attention_mask |
|
|
|
|
|
def html_escape(html): |
|
r""" |
|
replace the special expressions in the html file for specific punctuation. |
|
""" |
|
html = html.replace('"', '"') |
|
html = html.replace('&', '&') |
|
html = html.replace('<', '<') |
|
html = html.replace('>', '>') |
|
html = html.replace(' ', ' ') |
|
return html |
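
# Illustrative example (sketch only): html_escape('&lt;p&gt;&amp;nbsp;hi&lt;/p&gt;') == '<p> hi</p>'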
|
|
|
def read_squad_examples(args, input_file, root_dir, is_training, tokenizer, simplify=False, max_depth=50, |
|
split_flag="n-eon", |
|
attention_width=None): |
|
r""" |
|
pre-process the data in json format into SRC Examples. |
|
|
|
Arguments: |
|
split_flag: |
|
attention_width: |
|
input_file (str): the inputting data file in json format. |
|
root_dir (str): the root directory of the raw WebSRC dataset, which contains the HTML files. |
|
is_training (bool): True if processing the training set, else False. |
|
tokenizer (Tokenizer): the tokenizer for PLM in use. |
|
method (str): the name of the method in use, choice: ['T-PLM', 'H-PLM', 'V-PLM']. |
|
simplify (bool): when setting to Ture, the returned Example will only contain document tokens, the id of the |
|
question-answers, and the total tag number in the corresponding html files. |
|
Returns: |
|
list[SRCExamples]: the resulting SRC Examples, contained all the needed information for the feature generation |
|
process, except when the argument simplify is setting to True; |
|
set[str]: all the tag names appeared in the processed dataset, e.g. <div>, <img/>, </p>, etc.. |
|
""" |
|
with open(input_file, "r", encoding='utf-8') as reader: |
|
input_data = json.load(reader)["data"] |
|
|
|
pad_tree_id_seq = [1001] * max_depth |
|
|
|
def is_whitespace(c): |
|
if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F: |
|
return True |
|
return False |
|
|
|
def html_to_text_list(h): |
|
tag_num, text_list = 0, [] |
|
for element in h.descendants: |
|
if (type(element) == bs4.element.NavigableString) and (element.strip()): |
|
text_list.append(element.strip()) |
|
if type(element) == bs4.element.Tag: |
|
tag_num += 1 |
|
return text_list, tag_num + 2 |
|
|
|
def html_to_text(h): |
|
tag_list = set() |
|
for element in h.descendants: |
|
if type(element) == bs4.element.Tag: |
|
element.attrs = {} |
|
temp = str(element).split() |
|
tag_list.add(temp[0]) |
|
tag_list.add(temp[-1]) |
|
return html_escape(str(h)), tag_list |
|
|
|
def adjust_offset(offset, text): |
|
text_list = text.split() |
|
cnt, adjustment = 0, [] |
|
for t in text_list: |
|
if not t: |
|
continue |
|
if t[0] == '<' and t[-1] == '>': |
|
adjustment.append(offset.index(cnt)) |
|
else: |
|
cnt += 1 |
|
add = 0 |
|
adjustment.append(len(offset)) |
|
for i in range(len(offset)): |
|
while i >= adjustment[add]: |
|
add += 1 |
|
offset[i] += add |
|
return offset |
|
|
|
def e_id_to_t_id(e_id, html): |
|
t_id = 0 |
|
for element in html.descendants: |
|
if type(element) == bs4.element.NavigableString and element.strip(): |
|
t_id += 1 |
|
if type(element) == bs4.element.Tag: |
|
if int(element.attrs['tid']) == e_id: |
|
break |
|
return t_id |
|
|
|
def calc_num_from_raw_text_list(t_id, l): |
|
n_char = 0 |
|
for i in range(t_id): |
|
n_char += len(l[i]) + 1 |
|
return n_char |
|
|
|
def word_to_tag_from_text(tokens, h): |
|
cnt, w_t, path = -1, [], [] |
|
unique_tids = set() |
|
for t in tokens[0:-2]: |
|
if len(t) < 2: |
|
w_t.append(path[-1]) |
|
unique_tids.add(path[-1]) |
|
continue |
|
if t[0] == '<' and t[-2] == '/': |
|
cnt += 1 |
|
w_t.append(cnt) |
|
unique_tids.add(cnt) |
|
continue |
|
if t[0] == '<' and t[1] != '/': |
|
cnt += 1 |
|
path.append(cnt) |
|
w_t.append(path[-1]) |
|
unique_tids.add(path[-1]) |
|
if t[0] == '<' and t[1] == '/': |
|
del path[-1] |
|
w_t.append(cnt + 1) |
|
unique_tids.add(cnt + 1) |
|
w_t.append(cnt + 2) |
|
unique_tids.add(cnt + 2) |
|
assert len(w_t) == len(tokens) |
|
        assert len(path) == 0, str(h)
|
return w_t, unique_tids |
|
|
|
def word_tag_offset(html): |
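        """Map every whitespace-separated word of the page text to the index (in document
        order) of the deepest tag containing it; two extra indices are appended at the end
        for the special 'no'/'yes' tokens."""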
|
cnt, w_t, t_w, tags, tags_tids = 0, [], [], [], [] |
|
for element in html.descendants: |
|
if type(element) == bs4.element.Tag: |
|
content = ' '.join(list(element.strings)).split() |
|
t_w.append({'start': cnt, 'len': len(content)}) |
|
tags.append('<' + element.name + '>') |
|
tags_tids.append(element['tid']) |
|
elif type(element) == bs4.element.NavigableString and element.strip(): |
|
text = element.split() |
|
tid = element.parent['tid'] |
|
ind = tags_tids.index(tid) |
|
for _ in text: |
|
w_t.append(ind) |
|
cnt += 1 |
|
assert cnt == len(w_t) |
|
w_t.append(len(t_w)) |
|
w_t.append(len(t_w) + 1) |
|
return w_t |
|
|
|
def subtoken_tag_offset(html, s_tok): |
|
w_t = word_tag_offset(html) |
|
s_t = [] |
|
unique_tids = set() |
|
for i in range(len(s_tok)): |
|
s_t.append(w_t[s_tok[i]]) |
|
unique_tids.add(w_t[s_tok[i]]) |
|
return s_t, unique_tids |
|
|
|
def subtoken_tag_offset_plus_eon(html, s_tok, all_doc_tokens): |
|
w_t = word_tag_offset(html) |
|
s_t = [] |
|
unique_tids = set() |
|
offset = 0 |
|
for i in range(len(s_tok)): |
|
if all_doc_tokens[i] not in ('<end-of-node>', tokenizer.sep_token, tokenizer.cls_token): |
|
s_t.append(w_t[s_tok[i] - offset]) |
|
unique_tids.add(w_t[s_tok[i] - offset]) |
|
else: |
|
prev_tid = s_t[-1] |
|
s_t.append(prev_tid) |
|
offset += 1 |
|
return s_t, unique_tids |
|
|
|
def check_visible(path1, path2, attention_width): |
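        """Two nodes are mutually visible if their tree distance (number of hops from each
        node up to their lowest common ancestor, computed from the tree-id paths) does not
        exceed attention_width."""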
|
i = 0 |
|
j = 0 |
|
dis = 0 |
|
lp1 = len(path1) |
|
lp2 = len(path2) |
|
while i < lp1 and j < lp2 and path1[i] == path2[j]: |
|
i += 1 |
|
j += 1 |
|
|
|
if i < lp1 and j < lp2: |
|
dis += lp1 - i + lp2 - j |
|
else: |
|
if i == lp1: |
|
dis += lp2 - j |
|
else: |
|
dis += lp1 - i |
|
|
|
if dis <= attention_width: |
|
return True |
|
return False |
|
|
|
|
|
def from_tids_to_box(html_fn, unique_tids, json_fn): |
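        """Look up each tid's bounding box in the page's layout JSON and normalise it to a
        0-1000 coordinate space relative to the page root (tid "2"), returning
        [x1, y1, x2, y2]; missing or negative boxes fall back to [0, 0, 0, 0]."""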
|
sorted_ids = sorted(unique_tids) |
|
        with open(json_fn, 'r') as f:
            data = json.load(f)
|
orig_width, orig_height = data['2']['rect']['width'], data['2']['rect']['height'] |
|
orig_x, orig_y = data['2']['rect']['x'], data['2']['rect']['y'] |
|
|
|
return_dict = {} |
|
for id in sorted_ids: |
|
if str(id) in data: |
|
x, y, width, height = data[str(id)]['rect']['x'], data[str(id)]['rect']['y'], data[str(id)]['rect']['width'], data[str(id)]['rect']['height'] |
|
resize_x = (x - orig_x) * 1000 // orig_width |
|
resize_y = (y - orig_y) * 1000 // orig_height |
|
|
|
resize_width = width * 1000 // orig_width |
|
resize_height = height * 1000 // orig_height |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if resize_x < 0 or resize_y < 0 or resize_width < 0 or resize_height < 0: |
|
return_dict[id] = [0, 0, 0, 0] |
|
else: |
|
return_dict[id] = [int(resize_x), int(resize_y), int(resize_x+resize_width), int(resize_y+resize_height)] |
|
else: |
|
return_dict[id] = [0,0,0,0] |
|
|
|
return return_dict |
|
|
|
def get_visible_matrix(unique_tids, tree_id_map, attention_width): |
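        """For each tag id, collect the tag ids (itself included) whose tree distance is within
        attention_width; ids with a padded tree-id path get an empty list. Returns None when
        attention_width is None."""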
|
if attention_width is None: |
|
return None |
|
unique_tids_list = list(unique_tids) |
|
visible_matrix = collections.defaultdict(list) |
|
for i in range(len(unique_tids_list)): |
|
if tree_id_map[unique_tids_list[i]] == pad_tree_id_seq: |
|
visible_matrix[unique_tids_list[i]] = list() |
|
continue |
|
visible_matrix[unique_tids_list[i]].append(unique_tids_list[i]) |
|
for j in range(i + 1, len(unique_tids_list)): |
|
if check_visible(tree_id_map[unique_tids_list[i]], tree_id_map[unique_tids_list[j]], attention_width): |
|
visible_matrix[unique_tids_list[i]].append(unique_tids_list[j]) |
|
visible_matrix[unique_tids_list[j]].append(unique_tids_list[i]) |
|
return visible_matrix |
|
|
|
examples = [] |
|
all_tag_list = set() |
|
total_num = sum([len(entry["websites"]) for entry in input_data]) |
|
with tqdm(total=total_num, desc="Converting websites to examples") as t: |
|
for entry in input_data: |
|
|
|
domain = entry["domain"] |
|
for website in entry["websites"]: |
|
|
|
|
|
|
|
page_id = website["page_id"] |
|
|
|
curr_dir = osp.join(root_dir, domain, page_id[0:2], 'processed_data') |
|
html_fn = osp.join(curr_dir, page_id + '.html') |
|
json_fn = osp.join(curr_dir, page_id + '.json') |
|
|
|
|
|
|
|
                with open(html_fn) as f:
                    html_file = f.read()
|
html_code = bs(html_file, "html.parser") |
|
raw_text_list, tag_num = html_to_text_list(html_code) |
|
|
|
|
|
|
|
|
|
doc_tokens = [] |
|
char_to_word_offset = [] |
|
|
|
|
|
|
|
|
|
if split_flag in ["y-eon", "y-sep", "y-cls"]: |
|
prev_is_whitespace = True |
|
for i, doc_string in enumerate(raw_text_list): |
|
for c in doc_string: |
|
if is_whitespace(c): |
|
prev_is_whitespace = True |
|
else: |
|
if prev_is_whitespace: |
|
doc_tokens.append(c) |
|
else: |
|
doc_tokens[-1] += c |
|
prev_is_whitespace = False |
|
char_to_word_offset.append(len(doc_tokens) - 1) |
|
|
|
if i < len(raw_text_list) - 1: |
|
prev_is_whitespace = True |
|
char_to_word_offset.append(len(doc_tokens) - 1) |
|
|
|
if split_flag == "y-eon": |
|
doc_tokens.append('<end-of-node>') |
|
elif split_flag == "y-sep": |
|
doc_tokens.append(tokenizer.sep_token) |
|
elif split_flag == "y-cls": |
|
doc_tokens.append(tokenizer.cls_token) |
|
else: |
|
raise ValueError("Split flag should be `y-eon` or `y-sep` or `y-cls`") |
|
prev_is_whitespace = True |
|
|
|
elif split_flag =="n-eon" or split_flag == "y-hplm": |
|
page_text = ' '.join(raw_text_list) |
|
prev_is_whitespace = True |
|
for c in page_text: |
|
if is_whitespace(c): |
|
prev_is_whitespace = True |
|
else: |
|
if prev_is_whitespace: |
|
doc_tokens.append(c) |
|
else: |
|
doc_tokens[-1] += c |
|
prev_is_whitespace = False |
|
char_to_word_offset.append(len(doc_tokens) - 1) |
|
|
|
|
|
doc_tokens.append('no') |
|
char_to_word_offset.append(len(doc_tokens) - 1) |
|
doc_tokens.append('yes') |
|
char_to_word_offset.append(len(doc_tokens) - 1) |
|
|
|
if split_flag == "y-hplm": |
|
                    real_text, tag_list = html_to_text(bs(html_file, "html.parser"))
|
all_tag_list = all_tag_list | tag_list |
|
char_to_word_offset = adjust_offset(char_to_word_offset, real_text) |
|
doc_tokens = real_text.split() |
|
doc_tokens.append('no') |
|
doc_tokens.append('yes') |
|
doc_tokens = [i for i in doc_tokens if i] |
|
|
|
else: |
|
tag_list = [] |
|
|
|
assert len(doc_tokens) == char_to_word_offset[-1] + 1, (len(doc_tokens), char_to_word_offset[-1]) |
|
|
|
if simplify: |
|
for qa in website["qas"]: |
|
qas_id = qa["id"] |
|
example = SRCExample(doc_tokens=doc_tokens, qas_id=qas_id, tag_num=tag_num) |
|
examples.append(example) |
|
t.update(1) |
|
else: |
|
|
|
|
|
tok_to_orig_index = [] |
|
orig_to_tok_index = [] |
|
all_doc_tokens = [] |
|
for (i, token) in enumerate(doc_tokens): |
|
orig_to_tok_index.append(len(all_doc_tokens)) |
|
if token in tag_list: |
|
sub_tokens = [token] |
|
else: |
|
sub_tokens = tokenizer.tokenize(token) |
|
for sub_token in sub_tokens: |
|
tok_to_orig_index.append(i) |
|
all_doc_tokens.append(sub_token) |
|
|
|
|
|
if split_flag in ["y-eon", "y-sep", "y-cls"]: |
|
tok_to_tags_index, unique_tids = subtoken_tag_offset_plus_eon(html_code, tok_to_orig_index, |
|
all_doc_tokens) |
|
elif split_flag == "n-eon": |
|
tok_to_tags_index, unique_tids = subtoken_tag_offset(html_code, tok_to_orig_index) |
|
|
|
elif split_flag == "y-hplm": |
|
tok_to_tags_index, unique_tids = word_to_tag_from_text(all_doc_tokens, html_code) |
|
|
|
else: |
|
raise ValueError("Unsupported split_flag!") |
|
|
|
|
|
|
|
xpath_tag_map, xpath_subs_map, tree_id_map = get_xpath_and_treeid4tokens(html_code, unique_tids, |
|
max_depth=max_depth) |
|
|
|
xpath_box = from_tids_to_box(html_fn, unique_tids, json_fn) |
|
|
|
|
|
assert tok_to_tags_index[-1] == tag_num - 1, (tok_to_tags_index[-1], tag_num - 1) |
|
|
|
|
|
visible_matrix = get_visible_matrix(unique_tids, tree_id_map, attention_width=attention_width) |
|
|
|
|
|
for qa in website["qas"]: |
|
qas_id = qa["id"] |
|
question_text = qa["question"] |
|
start_position = None |
|
end_position = None |
|
orig_answer_text = None |
|
|
|
if is_training: |
|
if len(qa["answers"]) != 1: |
|
raise ValueError( |
|
"For training, each question should have exactly 1 answer.") |
|
answer = qa["answers"][0] |
|
orig_answer_text = answer["text"] |
|
if answer["element_id"] == -1: |
|
num_char = len(char_to_word_offset) - 2 |
|
else: |
|
num_char = calc_num_from_raw_text_list(e_id_to_t_id(answer["element_id"], html_code), |
|
raw_text_list) |
|
answer_offset = num_char + answer["answer_start"] |
|
answer_length = len(orig_answer_text) if answer["element_id"] != -1 else 1 |
|
start_position = char_to_word_offset[answer_offset] |
|
end_position = char_to_word_offset[answer_offset + answer_length - 1] |
|
|
|
|
|
|
|
|
|
|
|
|
|
actual_text = " ".join([w for w in doc_tokens[start_position:(end_position + 1)] |
|
if (w[0] != '<' or w[-1] != '>') |
|
and w != "<end-of-node>" |
|
and w != tokenizer.sep_token |
|
and w != tokenizer.cls_token]) |
|
cleaned_answer_text = " ".join(whitespace_tokenize(orig_answer_text)) |
|
if actual_text.find(cleaned_answer_text) == -1: |
|
logging.warning("Could not find answer of question %s: '%s' vs. '%s'", |
|
qa['id'], actual_text, cleaned_answer_text) |
|
continue |
|
|
|
example = SRCExample( |
|
doc_tokens=doc_tokens, |
|
qas_id=qas_id, |
|
tag_num=tag_num, |
|
question_text=question_text, |
|
html_code=html_code, |
|
orig_answer_text=orig_answer_text, |
|
start_position=start_position, |
|
end_position=end_position, |
|
tok_to_orig_index=tok_to_orig_index, |
|
orig_to_tok_index=orig_to_tok_index, |
|
all_doc_tokens=all_doc_tokens, |
|
tok_to_tags_index=tok_to_tags_index, |
|
xpath_tag_map=xpath_tag_map, |
|
xpath_subs_map=xpath_subs_map, |
|
xpath_box=xpath_box, |
|
tree_id_map=tree_id_map, |
|
visible_matrix=visible_matrix |
|
) |
|
|
|
examples.append(example) |
|
|
|
|
|
if args.web_num_features != 0: |
|
if len(examples) >= args.web_num_features: |
|
return examples, all_tag_list |
|
|
|
t.update(1) |
|
return examples, all_tag_list |
|
|
|
|
|
def load_and_cache_examples(args, tokenizer, max_depth=50, evaluate=False, output_examples=False): |
|
r""" |
|
Load and process the raw data. |
|
""" |
|
if args.local_rank not in [-1, 0] and not evaluate: |
|
torch.distributed.barrier() |
|
|
|
|
|
|
|
input_file = args.web_eval_file if evaluate else args.web_train_file |
|
|
|
cached_features_file = os.path.join(args.cache_dir, 'cached_{}_{}_{}_{}_{}_{}'.format( |
|
'dev' if evaluate else 'train', |
|
"markuplm", |
|
str(args.max_seq_length), |
|
str(max_depth), |
|
args.web_num_features, |
|
args.model_type |
|
)) |
|
if not os.path.exists(os.path.dirname(cached_features_file)): |
|
os.makedirs(os.path.dirname(cached_features_file)) |
|
|
|
if os.path.exists(cached_features_file) and not args.overwrite_cache: |
|
print("Loading features from cached file %s", cached_features_file) |
|
features = torch.load(cached_features_file) |
|
if output_examples: |
|
examples, tag_list = read_squad_examples(args, input_file=input_file, |
|
root_dir=args.web_root_dir, |
|
is_training=not evaluate, |
|
tokenizer=tokenizer, |
|
simplify=True, |
|
max_depth=max_depth |
|
) |
|
else: |
|
examples = None |
|
else: |
|
print("Creating features from dataset file at %s", input_file) |
|
|
|
examples, _ = read_squad_examples(args, input_file=input_file, |
|
root_dir=args.web_root_dir, |
|
is_training=not evaluate, |
|
tokenizer=tokenizer, |
|
simplify=False, |
|
max_depth=max_depth) |
|
|
|
features = convert_examples_to_features(examples=examples, |
|
tokenizer=tokenizer, |
|
max_seq_length=args.max_seq_length, |
|
doc_stride=args.doc_stride, |
|
max_query_length=args.max_query_length, |
|
is_training=not evaluate, |
|
cls_token=tokenizer.cls_token, |
|
sep_token=tokenizer.sep_token, |
|
pad_token=tokenizer.pad_token_id, |
|
sequence_a_segment_id=0, |
|
sequence_b_segment_id=0, |
|
max_depth=max_depth) |
|
|
|
if args.local_rank in [-1, 0] and args.web_save_features: |
|
print("Saving features into cached file %s", cached_features_file) |
|
torch.save(features, cached_features_file) |
|
|
|
if args.local_rank == 0 and not evaluate: |
|
torch.distributed.barrier() |
|
|
|
|
|
|
|
all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long) |
|
all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long) |
|
all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long) |
|
|
|
all_xpath_tags_seq = torch.tensor([f.xpath_tags_seq for f in features], dtype=torch.long) |
|
all_xpath_subs_seq = torch.tensor([f.xpath_subs_seq for f in features], dtype=torch.long) |
|
all_xpath_box_seq = torch.tensor([f.xpath_box_seq for f in features], dtype=torch.long) |
|
|
|
|
|
if evaluate: |
|
all_feature_index = torch.arange(all_input_ids.size(0), dtype=torch.long) |
|
dataset = StrucDataset(all_input_ids, all_input_mask, all_segment_ids, all_feature_index, |
|
all_xpath_tags_seq, all_xpath_subs_seq, all_xpath_box_seq) |
|
else: |
|
all_start_positions = torch.tensor([f.start_position for f in features], dtype=torch.long) |
|
all_end_positions = torch.tensor([f.end_position for f in features], dtype=torch.long) |
|
dataset = StrucDataset(all_input_ids, all_input_mask, all_segment_ids, |
|
all_xpath_tags_seq, all_xpath_subs_seq, |
|
all_start_positions, all_end_positions, all_xpath_box_seq) |
|
|
|
if output_examples: |
|
dataset = (dataset, examples, features) |
|
return dataset |
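

# Usage sketch (illustrative only): the attribute names below mirror how this module reads
# `args`, but the concrete values and paths are placeholders, not the original configuration.
def _example_load_websrc():
    args = argparse.Namespace(
        local_rank=-1,
        cache_dir="./cache",
        web_train_file="data/websrc_train.json",      # placeholder path
        web_eval_file="data/websrc_dev.json",         # placeholder path
        web_root_dir="data/websrc",                   # placeholder path
        web_num_features=0,
        web_save_features=True,
        overwrite_cache=False,
        model_type="markuplm",
        max_seq_length=384,
        doc_stride=128,
        max_query_length=64,
    )
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    return load_and_cache_examples(args, tokenizer, evaluate=False)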
|
|
|
|
|
|
|
|
|
def convert_examples_to_features(examples, tokenizer, max_seq_length, doc_stride, max_query_length, is_training, |
|
cls_token='[CLS]', sep_token='[SEP]', pad_token=0, |
|
sequence_a_segment_id=0, sequence_b_segment_id=1, |
|
cls_token_segment_id=0, pad_token_segment_id=0, |
|
mask_padding_with_zero=True, max_depth=50): |
|
r""" |
|
Converting the SRC Examples further into the features for all the input doc spans. |
|
|
|
Arguments: |
|
examples (list[SRCExample]): the list of SRC Examples to process. |
|
tokenizer (Tokenizer): the tokenizer for PLM in use. |
|
max_seq_length (int): the max length of the total sub-token sequence, including the question, cls token, sep |
|
tokens, and documents; if the length of the input is bigger than max_seq_length, the input |
|
will be cut into several doc spans. |
|
doc_stride (int): the stride length when the input is cut into several doc spans. |
|
        max_query_length (int): the max length of the sub-token sequence of the question; the question will be truncated
                            if it is longer than max_query_length.
|
is_training (bool): True if processing the training set, else False. |
|
cls_token (str): the cls token in use, default is '[CLS]'. |
|
sep_token (str): the sep token in use, default is '[SEP]'. |
|
        pad_token (int): the id of the padding token in use when the total sub-token length is smaller than
                         max_seq_length; default is 0, which corresponds to the '[PAD]' token.
|
sequence_a_segment_id: the segment id for the first sequence (the question), default is 0. |
|
sequence_b_segment_id: the segment id for the second sequence (the html file), default is 1. |
|
cls_token_segment_id: the segment id for the cls token, default is 0. |
|
pad_token_segment_id: the segment id for the padding tokens, default is 0. |
|
mask_padding_with_zero: determine the pattern of the returned input mask; 0 for padding tokens and 1 for others |
|
when True, and vice versa. |
|
Returns: |
|
list[InputFeatures]: the resulting input features for all the input doc spans |
|
""" |
|
|
|
pad_x_tag_seq = [216] * max_depth |
|
pad_x_subs_seq = [1001] * max_depth |
|
pad_x_box = [0,0,0,0] |
|
pad_tree_id_seq = [1001] * max_depth |
|
|
|
unique_id = 1000000000 |
|
features = [] |
|
for (example_index, example) in enumerate(tqdm(examples, desc="Converting examples to features")): |
|
|
|
xpath_tag_map = example.xpath_tag_map |
|
xpath_subs_map = example.xpath_subs_map |
|
xpath_box = example.xpath_box |
|
tree_id_map = example.tree_id_map |
|
visible_matrix = example.visible_matrix |
|
|
|
query_tokens = tokenizer.tokenize(example.question_text) |
|
if len(query_tokens) > max_query_length: |
|
query_tokens = query_tokens[0:max_query_length] |
|
|
|
tok_start_position = None |
|
tok_end_position = None |
|
if is_training: |
|
tok_start_position = example.orig_to_tok_index[example.start_position] |
|
if example.end_position < len(example.doc_tokens) - 1: |
|
tok_end_position = example.orig_to_tok_index[example.end_position + 1] - 1 |
|
else: |
|
tok_end_position = len(example.all_doc_tokens) - 1 |
|
(tok_start_position, tok_end_position) = _improve_answer_span( |
|
example.all_doc_tokens, tok_start_position, tok_end_position, tokenizer, |
|
example.orig_answer_text) |
|
|
|
|
|
max_tokens_for_doc = max_seq_length - len(query_tokens) - 3 |
|
|
|
|
|
|
|
|
|
_DocSpan = collections.namedtuple( |
|
"DocSpan", ["start", "length"]) |
|
doc_spans = [] |
|
start_offset = 0 |
|
while start_offset < len(example.all_doc_tokens): |
|
length = len(example.all_doc_tokens) - start_offset |
|
if length > max_tokens_for_doc: |
|
length = max_tokens_for_doc |
|
doc_spans.append(_DocSpan(start=start_offset, length=length)) |
|
if start_offset + length == len(example.all_doc_tokens): |
|
break |
|
start_offset += min(length, doc_stride) |
|
|
|
for (doc_span_index, doc_span) in enumerate(doc_spans): |
|
tokens = [] |
|
token_to_orig_map = {} |
|
token_is_max_context = {} |
|
segment_ids = [] |
|
token_to_tag_index = [] |
|
|
|
|
|
tokens.append(cls_token) |
|
segment_ids.append(cls_token_segment_id) |
|
token_to_tag_index.append(example.tag_num) |
|
|
|
|
|
tokens += query_tokens |
|
segment_ids += [sequence_a_segment_id] * len(query_tokens) |
|
token_to_tag_index += [example.tag_num] * len(query_tokens) |
|
|
|
|
|
tokens.append(sep_token) |
|
segment_ids.append(sequence_a_segment_id) |
|
token_to_tag_index.append(example.tag_num) |
|
|
|
|
|
for i in range(doc_span.length): |
|
split_token_index = doc_span.start + i |
|
token_to_orig_map[len(tokens)] = example.tok_to_orig_index[split_token_index] |
|
token_to_tag_index.append(example.tok_to_tags_index[split_token_index]) |
|
|
|
is_max_context = _check_is_max_context(doc_spans, doc_span_index, |
|
split_token_index) |
|
token_is_max_context[len(tokens)] = is_max_context |
|
tokens.append(example.all_doc_tokens[split_token_index]) |
|
segment_ids.append(sequence_b_segment_id) |
|
paragraph_len = doc_span.length |
|
|
|
|
|
tokens.append(sep_token) |
|
segment_ids.append(sequence_b_segment_id) |
|
token_to_tag_index.append(example.tag_num) |
|
|
|
input_ids = tokenizer.convert_tokens_to_ids(tokens) |
|
|
|
|
|
input_mask = [1 if mask_padding_with_zero else 0] * len(input_ids) |
|
|
|
|
|
while len(input_ids) < max_seq_length: |
|
input_ids.append(pad_token) |
|
input_mask.append(0 if mask_padding_with_zero else 1) |
|
segment_ids.append(pad_token_segment_id) |
|
token_to_tag_index.append(example.tag_num) |
|
|
|
assert len(input_ids) == max_seq_length |
|
assert len(input_mask) == max_seq_length |
|
assert len(segment_ids) == max_seq_length |
|
assert len(token_to_tag_index) == max_seq_length |
|
|
|
span_is_impossible = False |
|
start_position = None |
|
end_position = None |
|
if is_training: |
|
|
|
|
|
doc_start = doc_span.start |
|
doc_end = doc_span.start + doc_span.length - 1 |
|
out_of_span = False |
|
if not (tok_start_position >= doc_start and |
|
tok_end_position <= doc_end): |
|
out_of_span = True |
|
if out_of_span: |
|
span_is_impossible = True |
|
start_position = 0 |
|
end_position = 0 |
|
else: |
|
doc_offset = len(query_tokens) + 2 |
|
start_position = tok_start_position - doc_start + doc_offset |
|
end_position = tok_end_position - doc_start + doc_offset |
|
|
|
|
|
|
|
|
|
|
|
|
xpath_tags_seq = [xpath_tag_map.get(tid, pad_x_tag_seq) for tid in token_to_tag_index] |
|
xpath_subs_seq = [xpath_subs_map.get(tid, pad_x_subs_seq) for tid in token_to_tag_index] |
|
xpath_box_seq = [xpath_box.get(tid, pad_x_box) for tid in token_to_tag_index] |
|
|
|
|
|
|
|
|
|
if visible_matrix is not None: |
|
extended_attention_mask = [] |
|
for tid in token_to_tag_index: |
|
if tid == example.tag_num: |
|
extended_attention_mask.append(input_mask) |
|
else: |
|
visible_tids = visible_matrix[tid] |
|
if len(visible_tids) == 0: |
|
extended_attention_mask.append(input_mask) |
|
continue |
|
                        visible_per_token = []
                        # Use a different name for the inner loop variable so it does not shadow
                        # the outer `tid` (behaviour is unchanged).
                        for i, other_tid in enumerate(token_to_tag_index):
                            if other_tid == example.tag_num and input_mask[i] == (1 if mask_padding_with_zero else 0):
                                visible_per_token.append(1 if mask_padding_with_zero else 0)
                            elif other_tid in visible_tids:
                                visible_per_token.append(1 if mask_padding_with_zero else 0)
                            else:
                                visible_per_token.append(0 if mask_padding_with_zero else 1)
                        extended_attention_mask.append(visible_per_token)
|
else: |
|
extended_attention_mask = None |
|
|
|
features.append( |
|
InputFeatures( |
|
unique_id=unique_id, |
|
example_index=example_index, |
|
page_id=example.qas_id[:-5], |
|
doc_span_index=doc_span_index, |
|
tokens=tokens, |
|
token_to_orig_map=token_to_orig_map, |
|
token_is_max_context=token_is_max_context, |
|
input_ids=input_ids, |
|
input_mask=input_mask, |
|
segment_ids=segment_ids, |
|
paragraph_len=paragraph_len, |
|
start_position=start_position, |
|
end_position=end_position, |
|
token_to_tag_index=token_to_tag_index, |
|
is_impossible=span_is_impossible, |
|
xpath_tags_seq=xpath_tags_seq, |
|
xpath_subs_seq=xpath_subs_seq, |
|
xpath_box_seq=xpath_box_seq, |
|
extended_attention_mask=extended_attention_mask, |
|
)) |
|
unique_id += 1 |
|
|
|
return features |
|
|
|
def get_websrc_dataset(args, tokenizer, evaluate=False, output_examples=False): |
|
if not evaluate: |
|
websrc_dataset = load_and_cache_examples(args, tokenizer, evaluate=evaluate, output_examples=False) |
|
return websrc_dataset |
|
else: |
|
dataset, examples, features = load_and_cache_examples(args, tokenizer, evaluate=evaluate, output_examples=True) |
|
return dataset, examples, features |
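
# e.g. (sketch): train_set = get_websrc_dataset(args, tokenizer)
#                dev_set, dev_examples, dev_features = get_websrc_dataset(args, tokenizer, evaluate=True)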
|
|