""" Named entity recognition fine-tuning: utilities to work with the CoNLL-2003 task. """
|
|
|
import logging |
|
import os |
|
from tqdm import tqdm
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
class InputExample(object): |
|
"""A single training/test example for token classification.""" |
|
|
|
    def __init__(self, guid, words, labels=None):

        """Constructs an InputExample.
|
|
|
Args: |
|
guid: Unique id for the example. |
|
words: list. The words of the sequence. |
|
labels: (Optional) list. The labels for each word of the sequence. This should be |
|
specified for train and dev examples, but not for test examples. |
|
""" |
|
self.guid = guid |
|
self.words = words |
|
self.labels = labels |
|
|
|
|
|
class InputFeatures(object): |
|
"""A single set of features of data.""" |
|
|
|
    def __init__(self, input_ids, input_mask, segment_ids, label_ids=None):
|
self.input_ids = input_ids |
|
self.input_mask = input_mask |
|
self.segment_ids = segment_ids |
|
self.label_ids = label_ids |
|
|
|
|
|
def read_examples_from_file(data_dir, mode):

    """Reads `<mode>.txt` from `data_dir` into a list of InputExample objects (format sketch below)."""

    file_path = os.path.join(data_dir, "{}.txt".format(mode))

    guid_index = 0

    examples = []

    with open(file_path, encoding="utf-8") as f:

        for line in f:

            line = line.strip()

            if not line:

                continue  # skip blank lines

            columns = line.split("\t")

            words = columns[0].split()

            labels = columns[1].split()

            assert len(words) == len(labels), "word/label count mismatch near example {}".format(guid_index + 1)

            guid_index += 1

            examples.append(InputExample(guid=guid_index, words=words, labels=labels))
|
|
|
return examples |
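
# Expected layout of "<mode>.txt", as implied by the parsing above (a sketch of a
# preprocessed file, not the native token-per-line CoNLL-2003 format): each line
# holds one sentence as whitespace-separated words, a single tab, then one
# whitespace-separated label per word, e.g.
#
#     EU rejects German call<TAB>B-ORG O B-MISC O
#     Peter Blackburn<TAB>B-PER I-PER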
|
|
|
|
|
def convert_examples_to_features( |
|
examples, |
|
label_list, |
|
max_seq_length, |
|
tokenizer, |
|
cls_token_at_end=False, |
|
cls_token="[CLS]", |
|
cls_token_segment_id=1, |
|
sep_token="[SEP]", |
|
sep_token_extra=False, |
|
pad_on_left=False, |
|
pad_token=0, |
|
pad_token_segment_id=0, |
|
pad_token_label_id=-100, |
|
sequence_a_segment_id=0, |
|
mask_padding_with_zero=True, |
|
    mode="train",

):

    """ Loads a data file into a list of `InputFeatures`

        `cls_token_at_end` defines the location of the CLS token:

            - False (Default, BERT/XLM pattern): [CLS] + A + [SEP] + B + [SEP]

            - True (XLNet/GPT pattern): A + [SEP] + B + [SEP] + [CLS]

        `cls_token_segment_id` defines the segment id associated with the CLS token (0 for BERT, 2 for XLNet)

    """
|
|
|
label_map = {label: i for i, label in enumerate(label_list)} |
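
    # Worked illustration of the labelling scheme applied in the loop below (an
    # assumed example: suppose "Germany" tokenizes into "Ger", "##many", its tag
    # maps to label id 5, and pad_token_label_id keeps its default of -100):
    #   tokens:     [CLS]   Ger   ##many   [SEP]
    #   label_ids:  -100    5     -100     -100
    # Only the first sub-token of each word carries the real label id; remaining
    # sub-tokens and the special tokens get pad_token_label_id, so a loss with
    # ignore_index=-100 skips them.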
|
|
|
features = [] |
|
for (ex_index, example) in enumerate(tqdm(examples)): |
|
if ex_index % 10000 == 0: |
|
logger.info("Writing example %d of %d", ex_index, len(examples)) |
|
|
|
tokens = [] |
|
label_ids = [] |
|
        for word, label in zip(example.words, example.labels):

            word_tokens = tokenizer.tokenize(word)

            # Some tokenizers can return an empty list (e.g. for stray control
            # characters); skip such words so tokens and label_ids stay aligned.
            if word_tokens:

                tokens.extend(word_tokens)

                # Real label id for the first sub-token only; padding label id
                # for the remaining sub-tokens.
                label_ids.extend([label_map[label]] + [pad_token_label_id] * (len(word_tokens) - 1))
|
|
|
|
|
        # Account for [CLS] and [SEP] ([SEP] twice when sep_token_extra is set),
        # then truncate sequences that would not fit once they are added.
        special_tokens_count = 3 if sep_token_extra else 2

        if len(tokens) > max_seq_length - special_tokens_count:

            tokens = tokens[: (max_seq_length - special_tokens_count)]

            label_ids = label_ids[: (max_seq_length - special_tokens_count)]
|
        # Append [SEP]; its label is pad_token_label_id so the loss ignores it.
        tokens += [sep_token]

        label_ids += [pad_token_label_id]

        if sep_token_extra:

            # Some models (e.g. RoBERTa) use an extra separator token.
            tokens += [sep_token]

            label_ids += [pad_token_label_id]

        segment_ids = [sequence_a_segment_id] * len(tokens)

        # Place [CLS] at the end (XLNet-style) or at the front (BERT-style default),
        # matching the patterns described in the docstring.
        if cls_token_at_end:

            tokens += [cls_token]

            label_ids += [pad_token_label_id]

            segment_ids += [cls_token_segment_id]

        else:

            tokens = [cls_token] + tokens

            label_ids = [pad_token_label_id] + label_ids

            segment_ids = [cls_token_segment_id] + segment_ids
|
|
|
        input_ids = tokenizer.convert_tokens_to_ids(tokens)

        # The mask has 1 for real tokens and 0 for padding tokens (the reverse
        # when mask_padding_with_zero is False); only real tokens are attended to.
        input_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)

        # Pad up to max_seq_length, on the left when pad_on_left is set.
        padding_length = max_seq_length - len(input_ids)
|
if pad_on_left: |
|
input_ids = ([pad_token] * padding_length) + input_ids |
|
input_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + input_mask |
|
segment_ids = ([pad_token_segment_id] * padding_length) + segment_ids |
|
label_ids = ([pad_token_label_id] * padding_length) + label_ids |
|
else: |
|
input_ids += [pad_token] * padding_length |
|
input_mask += [0 if mask_padding_with_zero else 1] * padding_length |
|
segment_ids += [pad_token_segment_id] * padding_length |
|
label_ids += [pad_token_label_id] * padding_length |
|
|
|
assert len(input_ids) == max_seq_length |
|
assert len(input_mask) == max_seq_length |
|
assert len(segment_ids) == max_seq_length |
|
assert len(label_ids) == max_seq_length |
|
|
|
if ex_index < 5: |
|
logger.info("*** Example ***") |
|
logger.info("guid: %s", example.guid) |
|
logger.info("tokens: %s", " ".join([str(x) for x in tokens])) |
|
logger.info("input_ids: %s", " ".join([str(x) for x in input_ids])) |
|
logger.info("input_mask: %s", " ".join([str(x) for x in input_mask])) |
|
logger.info("segment_ids: %s", " ".join([str(x) for x in segment_ids])) |
|
logger.info("label_ids: %s", " ".join([str(x) for x in label_ids])) |
|
|
|
features.append( |
|
InputFeatures(input_ids=input_ids, input_mask=input_mask, segment_ids=segment_ids, label_ids=label_ids) |
|
) |
|
return features |
|
|
|
|
|
def get_labels(path):

    """Reads one label per line from `path`; falls back to the default CoNLL-2003 label set."""

    if path:

        with open(path, "r", encoding="utf-8") as f:

            labels = f.read().splitlines()
|
if "O" not in labels: |
|
labels = ["O"] + labels |
|
return labels |
|
else: |
|
return ["O", "B-MISC", "I-MISC", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC"] |
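

# Minimal usage sketch (not part of the module API): it assumes `torch` and
# `transformers` are installed and that ./data contains train.txt (in the layout
# illustrated after read_examples_from_file) and labels.txt. The paths, model
# name, and max_seq_length below are placeholder choices, not values this module
# prescribes.
if __name__ == "__main__":
    import torch
    from torch.utils.data import TensorDataset
    from transformers import BertTokenizer

    logging.basicConfig(level=logging.INFO)

    tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
    labels = get_labels("./data/labels.txt")
    examples = read_examples_from_file("./data", "train")
    features = convert_examples_to_features(
        examples,
        labels,
        max_seq_length=128,
        tokenizer=tokenizer,
        cls_token_segment_id=0,  # BERT uses segment id 0 for [CLS]
    )

    # Wrap the feature lists in tensors so they can feed a torch DataLoader.
    dataset = TensorDataset(
        torch.tensor([f.input_ids for f in features], dtype=torch.long),
        torch.tensor([f.input_mask for f in features], dtype=torch.long),
        torch.tensor([f.segment_ids for f in features], dtype=torch.long),
        torch.tensor([f.label_ids for f in features], dtype=torch.long),
    )
    logger.info("Built a TensorDataset with %d examples", len(dataset))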