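# Text encoder components for the ACE model: a multilingual BPE tokenizer for
# lyrics (VoiceBpeTokenizer) paired with a UMT5-base encoder for the prompt.
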
from comfy import sd1_clip
from .spiece_tokenizer import SPieceTokenizer
import comfy.text_encoders.t5
import os
import re
import torch
import logging

from tokenizers import Tokenizer
from .ace_text_cleaners import multilingual_cleaners, japanese_to_romaji

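# Language codes accepted in "[xx]" lyric tags, mapped to their tag token ids
# in the BPE vocabulary (only the keys are used for lookups in this module).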
SUPPORT_LANGUAGES = {
    "en": 259, "de": 260, "fr": 262, "es": 284, "it": 285,
    "pt": 286, "pl": 294, "tr": 295, "ru": 267, "cs": 293,
    "nl": 297, "ar": 5022, "zh": 5023, "ja": 5412, "hu": 5753,
    "ko": 6152, "hi": 6680
}

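# Lines that start with a bracketed structure marker (e.g. "[verse]", "[chorus]")
# are always tokenized as English, regardless of the detected language.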
structure_pattern = re.compile(r"\[.*?\]")

DEFAULT_VOCAB_FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), "ace_lyrics_tokenizer", "vocab.json")


class VoiceBpeTokenizer:
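    """BPE tokenizer for lyrics with simple per-line language handling."""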
    def __init__(self, vocab_file=DEFAULT_VOCAB_FILE):
        self.tokenizer = None
        if vocab_file is not None:
            self.tokenizer = Tokenizer.from_file(vocab_file)

    def preprocess_text(self, txt, lang):
        txt = multilingual_cleaners(txt, lang)
        return txt

    def encode(self, txt, lang='en'):
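        """Clean one line, prefix it with a "[lang]" tag, and BPE-encode it.

        Spaces are replaced with the literal "[SPACE]" token before encoding.
        """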
        txt = self.preprocess_text(txt, lang)
        lang = "zh-cn" if lang == "zh" else lang
        txt = f"[{lang}]{txt}"
        txt = txt.replace(" ", "[SPACE]")
        return self.tokenizer.encode(txt).ids

    def get_lang(self, line):
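        """Split a leading two-letter "[xx]" tag off a line; default to English."""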
        if line.startswith("[") and line[3:4] == ']':
            lang = line[1:3].lower()
            if lang in SUPPORT_LANGUAGES:
                return lang, line[4:]
        return "en", line

    def __call__(self, string):
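        """Tokenize multi-line lyrics into a flat list of token ids.

        Id 261 opens the sequence; id 2 terminates each line and stands in
        for blank lines.
        """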
        lines = string.split("\n")
        lyric_token_idx = [261]
        for line in lines:
            line = line.strip()
            if not line:
                lyric_token_idx += [2]
                continue

            lang, line = self.get_lang(line)

            if lang not in SUPPORT_LANGUAGES:
                lang = "en"
            if "zh" in lang:
                lang = "zh"
            if "spa" in lang:
                lang = "es"

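            # If romaji conversion changes the line, treat it as Japanese.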
            try:
                line_out = japanese_to_romaji(line)
                if line_out != line:
                    lang = "ja"
                line = line_out
            except Exception:
                pass

            try:
                if structure_pattern.match(line):
                    token_idx = self.encode(line, "en")
                else:
                    token_idx = self.encode(line, lang)
                lyric_token_idx = lyric_token_idx + token_idx + [2]
            except Exception as e:
                logging.warning("tokenize error %s for line %s, language %s", e, line, lang)
        return {"input_ids": lyric_token_idx}

    @staticmethod
    def from_pretrained(path, **kwargs):
        return VoiceBpeTokenizer(path, **kwargs)

    def get_vocab(self):
        return {}


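# UMT5-base prompt encoder, built on the generic SDClipModel wrapper.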
class UMT5BaseModel(sd1_clip.SDClipModel):
    def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
        textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "umt5_config_base.json")
        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, zero_out_masked=False, model_options=model_options)

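# SentencePiece tokenizer for the UMT5 prompt; the spiece model comes from the
# checkpoint's tokenizer_data.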
class UMT5BaseTokenizer(sd1_clip.SDTokenizer):
    def __init__(self, embedding_directory=None, tokenizer_data={}):
        tokenizer = tokenizer_data.get("spiece_model", None)
        super().__init__(tokenizer, pad_with_end=False, embedding_size=768, embedding_key='umt5base', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=0, tokenizer_data=tokenizer_data)

    def state_dict(self):
        return {"spiece_model": self.tokenizer.serialize_model()}

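# SDTokenizer wrapper around the lyrics BPE tokenizer (pad token 2, no end token).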
class LyricsTokenizer(sd1_clip.SDTokenizer):
    def __init__(self, embedding_directory=None, tokenizer_data={}):
        tokenizer = DEFAULT_VOCAB_FILE
        super().__init__(tokenizer, pad_with_end=False, embedding_size=1024, embedding_key='lyrics', tokenizer_class=VoiceBpeTokenizer, has_start_token=True, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=2, has_end_token=False, tokenizer_data=tokenizer_data)

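# Combined tokenizer: the text prompt goes through UMT5, while lyrics are read
# from the "lyrics" kwarg and tokenized with the BPE lyrics tokenizer.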
class AceT5Tokenizer:
    def __init__(self, embedding_directory=None, tokenizer_data={}):
        self.voicebpe = LyricsTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
        self.umt5base = UMT5BaseTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)

    def tokenize_with_weights(self, text: str, return_word_ids=False, **kwargs):
        out = {}
        out["lyrics"] = self.voicebpe.tokenize_with_weights(kwargs.get("lyrics", ""), return_word_ids, **kwargs)
        out["umt5base"] = self.umt5base.tokenize_with_weights(text, return_word_ids, **kwargs)
        return out

    def untokenize(self, token_weight_pair):
        return self.umt5base.untokenize(token_weight_pair)

    def state_dict(self):
        return self.umt5base.state_dict()

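# Text-encoder module: encodes the prompt with UMT5-base and passes lyric token
# ids through untouched as extra conditioning.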
class AceT5Model(torch.nn.Module):
    def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
        super().__init__()
        self.umt5base = UMT5BaseModel(device=device, dtype=dtype, model_options=model_options)
        self.dtypes = set()
        if dtype is not None:
            self.dtypes.add(dtype)

    def set_clip_options(self, options):
        self.umt5base.set_clip_options(options)

    def reset_clip_options(self):
        self.umt5base.reset_clip_options()

    def encode_token_weights(self, token_weight_pairs):
        token_weight_pairs_umt5base = token_weight_pairs["umt5base"]
        token_weight_pairs_lyrics = token_weight_pairs["lyrics"]

        t5_out, t5_pooled = self.umt5base.encode_token_weights(token_weight_pairs_umt5base)

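        # Lyric token ids are not embedded here; they are handed to the model
        # as "conditioning_lyrics" alongside the UMT5 prompt embeddings.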
        lyrics_ids = torch.tensor([t[0] for t in token_weight_pairs_lyrics[0]]).unsqueeze(0)
        return t5_out, None, {"conditioning_lyrics": lyrics_ids}

    def load_sd(self, sd):
        return self.umt5base.load_sd(sd)