import torch
from transformers import DataCollatorForTokenClassification
from transformers.data.data_collator import pad_without_fast_tokenizer_warning


class ExtendedEmbeddingsDataCollatorForTokenClassification(DataCollatorForTokenClassification):
    """
    A data collator for token classification tasks with extended embeddings.

    This data collator extends the functionality of the `DataCollatorForTokenClassification` class
    by adding support for additional features such as `per`, `org`, and `loc`.

    Part of the code copied from: transformers.data.data_collator.DataCollatorForTokenClassification
    """

    def torch_call(self, features):
        # Collect the labels and the additional per-token entity features, if present.
        label_name = "label" if "label" in features[0].keys() else "labels"
        labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None
        per = [feature["per"] for feature in features] if "per" in features[0].keys() else None
        org = [feature["org"] for feature in features] if "org" in features[0].keys() else None
        loc = [feature["loc"] for feature in features] if "loc" in features[0].keys() else None

        # Remove labels and entity features before padding so the tokenizer only sees its own keys.
        no_labels_features = [{k: v for k, v in feature.items() if k not in [label_name, "per", "org", "loc"]} for feature in features]

        batch = pad_without_fast_tokenizer_warning(
            self.tokenizer,
            no_labels_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )

        if labels is None:
            return batch

        sequence_length = batch["input_ids"].shape[1]
        padding_side = self.tokenizer.padding_side

        def to_list(tensor_or_iterable):
            if isinstance(tensor_or_iterable, torch.Tensor):
                return tensor_or_iterable.tolist()
            return list(tensor_or_iterable)
        
        # Pad labels with `label_pad_token_id` and the entity features with 0, on the tokenizer's padding side.
        if padding_side == "right":
            batch[label_name] = [
                to_list(label) + [self.label_pad_token_id] * (sequence_length - len(label)) for label in labels
            ]
            batch["per"] = [
                to_list(p) + [0] * (sequence_length - len(p)) for p in per
            ]
            batch["org"] = [
                to_list(o) + [0] * (sequence_length - len(o)) for o in org
            ]
            batch["loc"] = [
                to_list(l) + [0] * (sequence_length - len(l)) for l in loc
            ]
        else:
            batch[label_name] = [
                [self.label_pad_token_id] * (sequence_length - len(label)) + to_list(label) for label in labels
            ]
            batch["per"] = [
                [0] * (sequence_length - len(p)) + to_list(p) for p in per
            ]
            batch["org"] = [
                [0] * (sequence_length - len(o)) + to_list(o) for o in org
            ]
            batch["loc"] = [
                [0] * (sequence_length - len(l)) + to_list(l) for l in loc
            ]

        batch[label_name] = torch.tensor(batch[label_name], dtype=torch.int64)
        batch["per"] = torch.tensor(batch["per"], dtype=torch.int64)
        batch["org"] = torch.tensor(batch["org"], dtype=torch.int64)
        batch["loc"] = torch.tensor(batch["loc"], dtype=torch.int64)
        return batch
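

# Minimal usage sketch (not part of the original module): assumes a fast tokenizer checkpoint
# ("bert-base-cased" is only illustrative) and toy, already-tokenized features with aligned
# per-token "labels", "per", "org", and "loc" lists.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
    collator = ExtendedEmbeddingsDataCollatorForTokenClassification(tokenizer=tokenizer)

    features = [
        {
            "input_ids": [101, 7592, 2088, 102],
            "attention_mask": [1, 1, 1, 1],
            "labels": [-100, 1, 2, -100],
            "per": [0, 1, 0, 0],
            "org": [0, 0, 0, 0],
            "loc": [0, 0, 1, 0],
        },
        {
            "input_ids": [101, 7592, 102],
            "attention_mask": [1, 1, 1],
            "labels": [-100, 1, -100],
            "per": [0, 0, 0],
            "org": [0, 1, 0],
            "loc": [0, 0, 0],
        },
    ]

    batch = collator(features)
    # All tensors are padded to the longest sequence in the batch; labels use -100 as padding,
    # the entity features use 0.
    print({k: v.shape for k, v in batch.items()})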