kemuriririn lj1995 committed on
Commit
5ff38ad
·
verified ·
0 Parent(s):

Super-squash branch 'main' using huggingface_hub

Browse files

Co-authored-by: lj1995 <[email protected]>

This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +35 -0
  2. AR/__init__.py +0 -0
  3. AR/data/__init__.py +0 -0
  4. AR/data/bucket_sampler.py +163 -0
  5. AR/data/data_module.py +76 -0
  6. AR/data/dataset.py +323 -0
  7. AR/models/__init__.py +0 -0
  8. AR/models/t2s_lightning_module.py +141 -0
  9. AR/models/t2s_lightning_module_onnx.py +107 -0
  10. AR/models/t2s_model.py +586 -0
  11. AR/models/t2s_model_onnx.py +338 -0
  12. AR/models/utils.py +229 -0
  13. AR/modules/__init__.py +0 -0
  14. AR/modules/activation.py +428 -0
  15. AR/modules/activation_onnx.py +178 -0
  16. AR/modules/embedding.py +81 -0
  17. AR/modules/embedding_onnx.py +63 -0
  18. AR/modules/lr_schedulers.py +83 -0
  19. AR/modules/optim.py +622 -0
  20. AR/modules/patched_mha_with_cache.py +465 -0
  21. AR/modules/patched_mha_with_cache_onnx.py +92 -0
  22. AR/modules/scaling.py +335 -0
  23. AR/modules/transformer.py +378 -0
  24. AR/modules/transformer_onnx.py +292 -0
  25. AR/text_processing/__init__.py +0 -0
  26. AR/text_processing/phonemizer.py +79 -0
  27. AR/text_processing/symbols.py +10 -0
  28. AR/utils/__init__.py +37 -0
  29. AR/utils/initialize.py +38 -0
  30. AR/utils/io.py +34 -0
  31. README.md +16 -0
  32. configs/s1.yaml +31 -0
  33. configs/s1big.yaml +31 -0
  34. configs/s1big2.yaml +31 -0
  35. configs/s1longer-v2.yaml +31 -0
  36. configs/s1longer.yaml +31 -0
  37. configs/s1mq.yaml +77 -0
  38. configs/s2.json +90 -0
  39. configs/train.yaml +32 -0
  40. download.py +5 -0
  41. feature_extractor/__init__.py +6 -0
  42. feature_extractor/cnhubert.py +109 -0
  43. feature_extractor/whisper_enc.py +25 -0
  44. inference_cli.py +55 -0
  45. inference_gui.py +310 -0
  46. inference_webui.py +678 -0
  47. module/__init__.py +0 -0
  48. module/attentions.py +709 -0
  49. module/attentions_onnx.py +354 -0
  50. module/commons.py +189 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
AR/__init__.py ADDED
File without changes
AR/data/__init__.py ADDED
File without changes
AR/data/bucket_sampler.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/data/bucket_sampler.py
2
+ # reference: https://github.com/lifeiteng/vall-e
3
+ import itertools
4
+ import math
5
+ import random
6
+ from random import shuffle
7
+ from typing import Iterator
8
+ from typing import Optional
9
+ from typing import TypeVar
10
+
11
+ import torch
12
+ import torch.distributed as dist
13
+ from torch.utils.data import Dataset
14
+ from torch.utils.data import Sampler
15
+
16
__all__ = [
    "DistributedBucketSampler",
]

T_co = TypeVar("T_co", covariant=True)


class DistributedBucketSampler(Sampler[T_co]):
    r"""Length-bucketed sampler that shards indices across replicas.

    At construction time the dataset is sorted by
    ``dataset.get_sample_length(i)`` and cut into buckets of neighbouring
    lengths.  On every ``__iter__`` call with ``shuffle=True`` the sampler:

    1. shuffles the indices inside each bucket,
    2. concatenates the buckets and slices the result into groups of
       ``batch_size * num_replicas`` indices,
    3. shuffles those groups,
    4. pads (or, with ``drop_last``, truncates) to ``total_size`` and yields
       every ``num_replicas``-th index starting at ``rank``.
    """

    def __init__(
        self,
        dataset: Dataset,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        shuffle: bool = True,
        seed: int = 0,
        drop_last: bool = False,
        batch_size: int = 32,
    ) -> None:
        """Build the length index and the buckets.

        Args:
            dataset: Sized dataset that also exposes ``get_sample_length(i)``.
            num_replicas: Number of participating processes.  When ``None`` it
                is read from ``torch.distributed`` if CUDA is available and a
                process group is initialised, otherwise it defaults to 1.
            rank: Index of this process.  Resolved like ``num_replicas``
                (default 0); when resolved here and CUDA is available the
                current CUDA device is also set to ``rank``.
            shuffle: Shuffle inside buckets and across groups on each epoch.
            seed: Base seed; the effective seed is ``seed + epoch``.
            drop_last: Truncate instead of padding when the dataset size is
                not divisible by ``num_replicas``.
            batch_size: Per-replica batch size used to size the groups.

        Raises:
            RuntimeError: ``torch.distributed`` is unavailable and
                ``num_replicas`` or ``rank`` was not given.
            ValueError: ``rank`` is outside ``[0, num_replicas - 1]``.
        """
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            # Only query the process group when one actually exists; a plain
            # single-process run on a CUDA machine has none.
            num_replicas = (
                dist.get_world_size()
                if torch.cuda.is_available() and dist.is_initialized()
                else 1
            )
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = (
                dist.get_rank()
                if torch.cuda.is_available() and dist.is_initialized()
                else 0
            )
            if torch.cuda.is_available():
                torch.cuda.set_device(rank)
        if rank >= num_replicas or rank < 0:
            raise ValueError(
                "Invalid rank {}, rank should be in the interval"
                " [0, {}]".format(rank, num_replicas - 1)
            )
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.drop_last = drop_last
        # If the dataset length is evenly divisible by # of replicas, then there
        # is no need to drop any data, since the dataset will be split equally.
        if (
            self.drop_last and len(self.dataset) % self.num_replicas != 0
        ):  # type: ignore[arg-type]
            # Split to nearest available length that is evenly divisible.
            # This is to ensure each rank receives the same amount of data when
            # using this Sampler.
            self.num_samples = math.ceil(
                (len(self.dataset) - self.num_replicas)
                / self.num_replicas  # type: ignore[arg-type]
            )
        else:
            self.num_samples = math.ceil(
                len(self.dataset) / self.num_replicas
            )  # type: ignore[arg-type]
        self.total_size = self.num_samples * self.num_replicas
        self.shuffle = shuffle
        self.seed = seed
        self.batch_size = batch_size
        self.id_with_length = self._get_sample_lengths()
        self.id_buckets = self.make_buckets(bucket_width=2.0)

    def _get_sample_lengths(self):
        """Return ``(index, length)`` pairs sorted by ascending length."""
        id_with_lengths = [
            (i, self.dataset.get_sample_length(i)) for i in range(len(self.dataset))
        ]
        id_with_lengths.sort(key=lambda pair: pair[1])
        return id_with_lengths

    def make_buckets(self, bucket_width: float = 2.0):
        """Group the length-sorted indices into buckets.

        A new bucket is started whenever a sample reaches the current upper
        bound; the bound then grows by exactly one ``bucket_width``, so bucket
        edges are not guaranteed to sit on multiples of the width.

        Returns:
            A list of non-empty lists of dataset indices.
        """
        buckets = []
        current = []
        upper = bucket_width
        for sample_id, sec in self.id_with_length:
            if sec < upper:
                current.append(sample_id)
                continue
            # Never emit an empty bucket (happens when the very first sample
            # is already at or above the first bound).
            if current:
                buckets.append(current)
            current = [sample_id]
            upper += bucket_width
        if current:
            buckets.append(current)
        return buckets

    def __iter__(self) -> Iterator[T_co]:
        """Yield this replica's indices for the current epoch."""
        if self.shuffle:
            # Deterministic shuffle based on epoch and seed.  This reseeds the
            # process-global ``random`` state, exactly as before.
            random.seed(self.epoch + self.seed)
            shuffled = []
            for bucket in self.id_buckets:
                members = bucket.copy()
                random.shuffle(members)
                shuffled.extend(members)
            # One group holds one batch for every replica.
            group = self.batch_size * self.num_replicas
            batches = [
                shuffled[start : start + group]
                for start in range(0, len(shuffled), group)
            ]
            random.shuffle(batches)
            indices = list(itertools.chain.from_iterable(batches))
        else:
            # type: ignore[arg-type]
            indices = list(range(len(self.dataset)))

        if not self.drop_last:
            # add extra samples to make it evenly divisible
            padding_size = self.total_size - len(indices)
            if padding_size <= len(indices):
                indices += indices[:padding_size]
            else:
                indices += (indices * math.ceil(padding_size / len(indices)))[
                    :padding_size
                ]
        else:
            # remove tail of data to make it evenly divisible.
            indices = indices[: self.total_size]
        assert len(indices) == self.total_size

        # subsample
        indices = indices[self.rank : self.total_size : self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self) -> int:
        """Number of indices this replica yields per epoch."""
        return self.num_samples

    def set_epoch(self, epoch: int) -> None:
        r"""
        Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas
        use a different random ordering for each epoch. Otherwise, the next iteration of this
        sampler will yield the same ordering.

        Args:
            epoch (int): Epoch number.
        """
        self.epoch = epoch
AR/data/data_module.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/data/data_module.py
2
+ # reference: https://github.com/lifeiteng/vall-e
3
+ from pytorch_lightning import LightningDataModule
4
+ from AR.data.bucket_sampler import DistributedBucketSampler
5
+ from AR.data.dataset import Text2SemanticDataset
6
+ from torch.utils.data import DataLoader
7
+
8
+
9
class Text2SemanticDataModule(LightningDataModule):
    """Lightning data module wiring ``Text2SemanticDataset`` into data loaders.

    The dev paths are accepted and stored but currently unused: ``setup``
    aliases the dev dataset to the training dataset.
    """

    def __init__(
        self,
        config,
        train_semantic_path,
        train_phoneme_path,
        dev_semantic_path=None,
        dev_phoneme_path=None,
    ):
        """Store the config and dataset paths.

        Args:
            config: Mapping read for ``config["data"]["num_workers"]``,
                ``["data"]["max_sec"]``, ``["data"]["pad_val"]``,
                ``["train"]["batch_size"]`` and optional ``["train"]["if_dpo"]``.
            train_semantic_path: Semantic-token file for training.
            train_phoneme_path: Phoneme file for training.
            dev_semantic_path: Stored only; not read by this class.
            dev_phoneme_path: Stored only; not read by this class.
        """
        super().__init__()
        self.config = config
        self.train_semantic_path = train_semantic_path
        self.train_phoneme_path = train_phoneme_path
        self.dev_semantic_path = dev_semantic_path
        self.dev_phoneme_path = dev_phoneme_path
        self.num_workers = self.config["data"]["num_workers"]

    @staticmethod
    def _worker_kwargs(num_workers):
        """Return DataLoader options that are only legal with worker processes.

        ``DataLoader`` raises ``ValueError`` if ``persistent_workers`` or
        ``prefetch_factor`` is supplied together with ``num_workers == 0``.
        """
        if num_workers > 0:
            return {"persistent_workers": True, "prefetch_factor": 16}
        return {}

    def prepare_data(self):
        """No download or preprocessing step is needed."""
        pass

    def setup(self, stage=None, output_logs=False):
        """Create the training dataset and reuse it as the dev dataset.

        ``stage`` and ``output_logs`` are accepted but not used here.
        """
        self._train_dataset = Text2SemanticDataset(
            phoneme_path=self.train_phoneme_path,
            semantic_path=self.train_semantic_path,
            max_sec=self.config["data"]["max_sec"],
            pad_val=self.config["data"]["pad_val"],
        )
        # A separate dev dataset (built from the dev_* paths with a
        # ``max_eval_sample`` cap) was disabled; validation and test loaders
        # therefore iterate the training data.
        self._dev_dataset = self._train_dataset

    def train_dataloader(self):
        """Return the bucketed, replica-sharded training loader."""
        # Must stay identical to the ``if_dpo`` check used when picking the
        # forward function in the Lightning module.
        use_dpo = self.config["train"].get("if_dpo", False) == True  # noqa: E712
        batch_size = self.config["train"]["batch_size"]
        if use_dpo:
            batch_size //= 2
        # Cap at a quarter of the dataset (never below 1).  Original note:
        # "prevents checkpoints from not being saved" - presumably so that tiny
        # datasets still yield several steps per epoch; confirm with trainer.
        batch_size = max(min(batch_size, len(self._train_dataset) // 4), 1)
        sampler = DistributedBucketSampler(self._train_dataset, batch_size=batch_size)
        return DataLoader(
            self._train_dataset,
            batch_size=batch_size,
            sampler=sampler,
            collate_fn=self._train_dataset.collate,
            num_workers=self.num_workers,
            **self._worker_kwargs(self.num_workers),
        )

    def val_dataloader(self):
        """Return a batch-size-1, unshuffled loader over the dev dataset."""
        # NOTE(review): max() forces at least 12 workers regardless of config;
        # min() may have been intended - confirm before changing.
        num_workers = max(self.num_workers, 12)
        return DataLoader(
            self._dev_dataset,
            batch_size=1,
            shuffle=False,
            collate_fn=self._train_dataset.collate,
            num_workers=num_workers,
            **self._worker_kwargs(num_workers),
        )

    # Original author's question: "will this ever be used?"
    def test_dataloader(self):
        """Return a batch-size-1, unshuffled, in-process loader over the dev dataset."""
        return DataLoader(
            self._dev_dataset,
            batch_size=1,
            shuffle=False,
            collate_fn=self._train_dataset.collate,
        )
AR/data/dataset.py ADDED
@@ -0,0 +1,323 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/data/dataset.py
2
+ # reference: https://github.com/lifeiteng/vall-e
3
+ import pdb
4
+ import sys
5
+
6
+ # sys.path.append("/data/docker/liujing04/gpt-vits/mq-vits-s1bert_no_bert")
7
+ import traceback, os
8
+ from typing import Dict
9
+ from typing import List
10
+
11
+ import numpy as np
12
+ import pandas as pd
13
+ import torch, json
14
+ from torch.utils.data import DataLoader
15
+ from torch.utils.data import Dataset
16
+ from transformers import AutoTokenizer
17
+
18
+ version = os.environ.get('version',None)
19
+
20
+ from text import cleaned_text_to_sequence
21
+
22
+ # from config import exp_dir
23
+
24
+
25
def batch_sequences(sequences: List[np.array], axis: int = 0, pad_value: int = 0):
    """Right-pad arrays along one axis to a common length and stack them.

    Args:
        sequences: Non-empty list of arrays; rank and dtype are taken from the
            first element.
        axis: Axis to pad; negative values count from the end.
        pad_value: Fill value, cast to the dtype of the first array.

    Returns:
        A new array with a leading batch dimension.
    """
    first = sequences[0]
    rank = first.ndim
    pad_axis = axis + rank if axis < 0 else axis
    fill = first.dtype.type(pad_value)
    lengths = [item.shape[pad_axis] for item in sequences]
    longest = np.max(lengths)

    def _pad(item, length):
        # Pad only at the end of ``pad_axis``; all other axes are untouched.
        widths = (
            [(0, 0)] * pad_axis
            + [(0, longest - length)]
            + [(0, 0)] * (rank - pad_axis - 1)
        )
        return np.pad(item, widths, mode="constant", constant_values=fill)

    return np.stack([_pad(item, length) for item, length in zip(sequences, lengths)])
44
+
45
+
46
class Text2SemanticDataset(Dataset):
    """Dataset pairing phoneme ids with semantic tokens for text-to-semantic training.

    All filtering and tokenisation happens once in ``init_batch``; afterwards
    only ``semantic_phoneme`` and ``item_names`` are kept in memory.
    """

    def __init__(
        self,
        phoneme_path: str,
        semantic_path: str,
        max_sample: int = None,
        max_sec: int = 100,
        pad_val: int = 1024,
        # min value of phoneme/sec
        min_ps_ratio: int = 3,
        # max value of phoneme/sec
        max_ps_ratio: int = 25,
    ) -> None:
        """Load both files and build the filtered sample list.

        Args:
            phoneme_path: Tab-separated text file; only lines with exactly four
                fields are used (name, phonemes, word2ph, text).
            semantic_path: Tab-separated file read with ``pandas.read_csv``;
                column 0 is the item name, column 1 the space-separated tokens.
            max_sample: If given, keep only the first ``max_sample`` rows.
            max_sec: Maximum duration in seconds (tokens / ``hz``).
            pad_val: Padding id for semantic tokens in ``collate``.
            min_ps_ratio: Minimum phonemes per second.
            max_ps_ratio: Maximum phonemes per second.
        """
        super().__init__()

        self.semantic_data = pd.read_csv(
            semantic_path, delimiter="\t", encoding="utf-8"
        )
        # Path attributes: phoneme file, sibling "3-bert" feature directory,
        # semantic file.
        self.path2 = phoneme_path
        self.path3 = "%s/3-bert" % (os.path.dirname(phoneme_path))
        self.path6 = semantic_path
        assert os.path.exists(self.path2)
        assert os.path.exists(self.path6)
        self.phoneme_data = {}
        with open(self.path2, "r", encoding="utf8") as f:
            lines = f.read().strip("\n").split("\n")

        for line in lines:
            tmp = line.split("\t")
            if len(tmp) != 4:
                continue
            self.phoneme_data[tmp[0]] = [tmp[1], tmp[2], tmp[3]]

        # pad for semantic tokens
        self.PAD: int = pad_val
        # Frame rate comes from the "hz" environment variable, e.g. "25hz".
        self.hz = int(os.environ.get("hz", "25hz")[:-2])

        # max seconds of semantic token
        self.max_sec = max_sec
        self.min_ps_ratio = min_ps_ratio
        self.max_ps_ratio = max_ps_ratio

        if max_sample is not None:
            self.semantic_data = self.semantic_data[:max_sample]

        # Parallel lists: (semantic_ids, phoneme_ids) tuples and their names.
        self.semantic_phoneme = []
        self.item_names = []

        self.inited = False
        self.init_batch()
        self.inited = True
        # The raw tables are no longer needed once the lists are built.
        del self.semantic_data
        del self.phoneme_data

    def init_batch(self):
        """Filter the raw rows and fill ``semantic_phoneme`` / ``item_names``.

        A row is skipped when its name has no phoneme entry, when it is longer
        than ``max_sec`` seconds, when phoneme conversion fails, when it has
        more than ``max_sec * hz / 2.5`` phonemes, or when its phonemes-per-
        second ratio is outside ``[min_ps_ratio, max_ps_ratio]``.  If fewer
        than 100 samples (but at least one) survive, the list is repeated.
        """
        semantic_data_len = len(self.semantic_data)
        phoneme_data_len = len(self.phoneme_data.keys())
        print("semantic_data_len:", semantic_data_len)
        print("phoneme_data_len:", phoneme_data_len)
        print(self.semantic_data)
        num_not_in = 0
        num_deleted_bigger = 0
        num_deleted_ps = 0
        for i in range(semantic_data_len):
            # Walk the rows in file order.
            item_name = self.semantic_data.iloc[i, 0]
            try:
                phoneme, word2ph, text = self.phoneme_data[item_name]
            except KeyError:
                traceback.print_exc()
                num_not_in += 1
                continue

            semantic_str = self.semantic_data.iloc[i, 1]
            # Token list of shape (T,); kept flat because len() is needed.
            semantic_ids = [int(token) for token in semantic_str.split(" ")]
            # Drop samples that are too long: duration is estimated from the
            # token count (max_sec comes from the config).
            if len(semantic_ids) > self.max_sec * self.hz:
                num_deleted_bigger += 1
                continue
            # Cheap enough to do up front instead of per item in __getitem__.
            phoneme = phoneme.split(" ")

            try:
                phoneme_ids = cleaned_text_to_sequence(phoneme, version)
            except Exception:
                # Best effort: an unconvertible sample is counted and skipped.
                traceback.print_exc()
                num_not_in += 1
                continue
            # Constant cap on phoneme count: semantic limit / 2.5.
            if len(phoneme_ids) > self.max_sec * self.hz / 2.5:
                num_deleted_ps += 1
                continue

            ps_ratio = len(phoneme_ids) / (len(semantic_ids) / self.hz)

            # Keep only samples within the allowed phonemes-per-second range.
            if ps_ratio > self.max_ps_ratio or ps_ratio < self.min_ps_ratio:
                num_deleted_ps += 1
                continue

            self.semantic_phoneme.append((semantic_ids, phoneme_ids))
            self.item_names.append(item_name)

        # Original note: with 20 samples nothing was padded; with 30 the data
        # was padded but checkpoints were still not saved, hence 100.
        min_num = 100
        leng = len(self.semantic_phoneme)
        # ``leng > 0`` avoids a ZeroDivisionError when every row was filtered
        # out; the dataset is then simply left empty.
        if 0 < leng < min_num:
            tmp1 = self.semantic_phoneme
            tmp2 = self.item_names
            self.semantic_phoneme = []
            self.item_names = []
            for _ in range(max(2, int(min_num / leng))):
                self.semantic_phoneme += tmp1
                self.item_names += tmp2
        if num_not_in > 0:
            print(f"there are {num_not_in} semantic datas not in phoneme datas")
        if num_deleted_bigger > 0:
            print(
                f"deleted {num_deleted_bigger} audios who's duration are bigger than {self.max_sec} seconds"
            )
        if num_deleted_ps > 0:
            # Original note: 4702 for LibriTTS; even annotated data needs this
            # filter because extreme ratios (around 100) do occur.
            print(
                f"deleted {num_deleted_ps} audios who's phoneme/sec are bigger than {self.max_ps_ratio} or smaller than {self.min_ps_ratio}"
            )
        # Example output recorded by the original author:
        #   there are 31 semantic datas not in phoneme datas
        #   deleted 34 audios who's duration are bigger than 54 seconds
        #   deleted 3190 audios who's phoneme/sec are bigger than 25 or smaller than 3
        #   dataset.__len__(): 366463
        # (345410 for LibriTTS)
        print("dataset.__len__():", self.__len__())

    def __get_item_names__(self) -> List[str]:
        """Return the item names, index-aligned with the samples."""
        return self.item_names

    def __len__(self) -> int:
        """Number of samples after filtering (and repetition, if any)."""
        return len(self.semantic_phoneme)

    def __getitem__(self, idx: int) -> Dict:
        """Return one sample; ``bert_feature`` is ``None`` if no ``.pt`` file exists."""
        semantic_ids, phoneme_ids = self.semantic_phoneme[idx]
        item_name = self.item_names[idx]
        phoneme_ids_len = len(phoneme_ids)
        # semantic tokens target
        semantic_ids_len = len(semantic_ids)

        path_bert = "%s/%s.pt" % (self.path3, item_name)
        if os.path.exists(path_bert):
            # NOTE(review): torch.load unpickles the file; only use feature
            # files produced by a trusted preprocessing run.
            bert_feature = torch.load(path_bert, map_location="cpu")
            assert bert_feature.shape[-1] == len(phoneme_ids)
        else:
            bert_feature = None
        return {
            "idx": idx,
            "phoneme_ids": phoneme_ids,
            "phoneme_ids_len": phoneme_ids_len,
            "semantic_ids": semantic_ids,
            "semantic_ids_len": semantic_ids_len,
            "bert_feature": bert_feature,
        }

    def get_sample_length(self, idx: int):
        """Return the sample duration in seconds (token count / ``hz``)."""
        semantic_ids = self.semantic_phoneme[idx][0]
        sec = 1.0 * len(semantic_ids) / self.hz
        return sec

    def collate(self, examples: List[Dict]) -> Dict:
        """Pad and stack a list of ``__getitem__`` results into one batch.

        Phoneme ids are padded with 0, semantic ids with ``self.PAD``; missing
        BERT features are left as zeros.
        """
        sample_index: List[int] = []
        phoneme_ids: List[torch.Tensor] = []
        phoneme_ids_lens: List[int] = []
        semantic_ids: List[torch.Tensor] = []
        semantic_ids_lens: List[int] = []

        for item in examples:
            sample_index.append(item["idx"])
            phoneme_ids.append(np.array(item["phoneme_ids"], dtype=np.int64))
            semantic_ids.append(np.array(item["semantic_ids"], dtype=np.int64))
            phoneme_ids_lens.append(item["phoneme_ids_len"])
            semantic_ids_lens.append(item["semantic_ids_len"])

        # pad 0
        phoneme_ids = batch_sequences(phoneme_ids)
        semantic_ids = batch_sequences(semantic_ids, pad_value=self.PAD)

        # Width of the BERT buffer, taken while the lengths are still ints.
        max_phoneme_len = max(phoneme_ids_lens)

        # convert each batch to torch.tensor
        phoneme_ids = torch.tensor(phoneme_ids)
        semantic_ids = torch.tensor(semantic_ids)
        phoneme_ids_lens = torch.tensor(phoneme_ids_lens)
        semantic_ids_lens = torch.tensor(semantic_ids_lens)
        bert_padded = torch.zeros(
            (len(examples), 1024, max_phoneme_len), dtype=torch.float32
        )

        for idx, item in enumerate(examples):
            bert = item["bert_feature"]
            if bert is not None:
                bert_padded[idx, :, : bert.shape[-1]] = bert

        return {
            # List[int]
            "ids": sample_index,
            # torch.Tensor (B, max_phoneme_length)
            "phoneme_ids": phoneme_ids,
            # torch.Tensor (B)
            "phoneme_ids_len": phoneme_ids_lens,
            # torch.Tensor (B, max_semantic_ids_length)
            "semantic_ids": semantic_ids,
            # torch.Tensor (B)
            "semantic_ids_len": semantic_ids_lens,
            # torch.Tensor (B, 1024, max_phoneme_length)
            "bert_feature": bert_padded,
        }
298
+
299
+
300
if __name__ == "__main__":
    # Manual smoke test: build the dataset from a developer-local dump and
    # iterate it once, printing a progress marker every 1000 batches.
    root_dir = "/data/docker/liujing04/gpt-vits/prepare/dump_mix/"
    dataset = Text2SemanticDataset(
        semantic_path=root_dir + "semantic_train.tsv",
        phoneme_path=root_dir + "phoneme_train.npy",
    )

    batch_size = 12
    dataloader = DataLoader(
        dataset, shuffle=False, collate_fn=dataset.collate, batch_size=batch_size
    )
    for step, batch in enumerate(dataloader):
        # To inspect a batch, print batch["ids"], batch["phoneme_ids"],
        # batch["phoneme_ids_len"], batch["semantic_ids"] and
        # batch["semantic_ids_len"] (with .shape) here.
        if step % 1000 != 0:
            continue
        print(step)
AR/models/__init__.py ADDED
File without changes
AR/models/t2s_lightning_module.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_lightning_module.py
2
+ # reference: https://github.com/lifeiteng/vall-e
3
+ import os, sys
4
+
5
+ now_dir = os.getcwd()
6
+ sys.path.append(now_dir)
7
+ from typing import Dict
8
+
9
+ import torch
10
+ from pytorch_lightning import LightningModule
11
+ from AR.models.t2s_model import Text2SemanticDecoder
12
+ from AR.modules.lr_schedulers import WarmupCosineLRSchedule
13
+ from AR.modules.optim import ScaledAdam
14
+
15
class Text2SemanticLightningModule(LightningModule):
    """Lightning wrapper around ``Text2SemanticDecoder`` using manual optimization."""

    def __init__(self, config, output_dir, is_train=True):
        """Build the decoder and, for training, load weights and create ``eval_dir``.

        Args:
            config: Mapping with the model settings; ``config.get("pretrained_s1")``
                optionally names a checkpoint whose ``"weight"`` entry is loaded.
            output_dir: Path-like supporting ``/``; only used when ``is_train``.
            is_train: Enables checkpoint loading, manual optimization,
                hyper-parameter saving and the ``eval`` directory.
        """
        super().__init__()
        self.config = config
        self.top_k = 3
        self.model = Text2SemanticDecoder(config=config, top_k=self.top_k)
        pretrained_s1 = config.get("pretrained_s1")
        if is_train:
            if pretrained_s1:
                # Print the load report returned by load_state_dict.
                print(
                    self.load_state_dict(
                        torch.load(pretrained_s1, map_location="cpu")["weight"]
                    )
                )
            self.automatic_optimization = False
            self.save_hyperparameters()
            self.eval_dir = output_dir / "eval"
            self.eval_dir.mkdir(parents=True, exist_ok=True)

    def training_step(self, batch: Dict, batch_idx: int):
        """Run one forward/backward pass and log loss, learning rate and accuracy."""
        opt = self.optimizers()
        scheduler = self.lr_schedulers()
        # DPO training uses ``forward``; otherwise the legacy ``forward_old``.
        if self.config["train"].get("if_dpo", False) == True:
            forward = self.model.forward
        else:
            forward = self.model.forward_old
        loss, acc = forward(
            batch["phoneme_ids"],
            batch["phoneme_ids_len"],
            batch["semantic_ids"],
            batch["semantic_ids_len"],
            batch["bert_feature"],
        )
        self.manual_backward(loss)
        # Gradient accumulation: step on batch 4, 8, 12, ...
        # NOTE(review): the first window therefore spans batches 0-4 (five
        # batches) - confirm whether that is intended.
        should_step = batch_idx > 0 and batch_idx % 4 == 0
        if should_step:
            opt.step()
            opt.zero_grad()
            scheduler.step()

        shared = dict(on_epoch=True, prog_bar=True, sync_dist=True)
        self.log("total_loss", loss, on_step=True, **shared)
        self.log("lr", scheduler.get_last_lr()[0], **shared)
        self.log(f"top_{self.top_k}_acc", acc, on_step=True, **shared)

    def validation_step(self, batch: Dict, batch_idx: int):
        """Validation is disabled; this step does nothing."""
        # A former implementation (val loss/accuracy logging plus saving
        # inferred semantic tokens under eval_dir) is switched off.
        return

    def configure_optimizers(self):
        """Return ``ScaledAdam`` with a warmup-cosine learning-rate schedule."""
        # ScaledAdam receives one list of names for the single parameter group.
        parameters_names = [[name for name, _ in self.model.named_parameters()]]
        lm_opt = ScaledAdam(
            self.model.parameters(),
            lr=0.01,
            betas=(0.9, 0.95),
            clipping_scale=2.0,
            parameters_names=parameters_names,
            show_dominant_parameters=False,
            clipping_update_period=1000,
        )
        opt_cfg = self.config["optimizer"]
        schedule = WarmupCosineLRSchedule(
            lm_opt,
            init_lr=opt_cfg["lr_init"],
            peak_lr=opt_cfg["lr"],
            end_lr=opt_cfg["lr_end"],
            warmup_steps=opt_cfg["warmup_steps"],
            total_steps=opt_cfg["decay_steps"],
        )
        return {"optimizer": lm_opt, "lr_scheduler": {"scheduler": schedule}}
AR/models/t2s_lightning_module_onnx.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_lightning_module.py
2
+ # reference: https://github.com/lifeiteng/vall-e
3
+ import os, sys
4
+
5
+ now_dir = os.getcwd()
6
+ sys.path.append(now_dir)
7
+ from typing import Dict
8
+
9
+ import torch
10
+ from pytorch_lightning import LightningModule
11
+ from AR.models.t2s_model_onnx import Text2SemanticDecoder
12
+ from AR.modules.lr_schedulers import WarmupCosineLRSchedule
13
+ from AR.modules.optim import ScaledAdam
14
+
15
+
16
class Text2SemanticLightningModule(LightningModule):
    """Lightning wrapper around the ONNX-variant ``Text2SemanticDecoder``."""

    def __init__(self, config, output_dir, is_train=True):
        """Build the decoder and, for training, load weights and create ``eval_dir``.

        Args:
            config: Mapping with the model settings; ``config.get("pretrained_s1")``
                optionally names a checkpoint whose ``"weight"`` entry is loaded.
            output_dir: Path-like supporting ``/``; only used when ``is_train``.
            is_train: Enables checkpoint loading, manual optimization,
                hyper-parameter saving and the ``eval`` directory.
        """
        super().__init__()
        self.config = config
        self.top_k = 3
        self.model = Text2SemanticDecoder(config=config, top_k=self.top_k)
        pretrained_s1 = config.get("pretrained_s1")
        if is_train:
            if pretrained_s1:
                # Print the load report returned by load_state_dict.
                print(
                    self.load_state_dict(
                        torch.load(pretrained_s1, map_location="cpu")["weight"]
                    )
                )
            self.automatic_optimization = False
            self.save_hyperparameters()
            self.eval_dir = output_dir / "eval"
            self.eval_dir.mkdir(parents=True, exist_ok=True)

    def training_step(self, batch: Dict, batch_idx: int):
        """Run one forward/backward pass and log loss, learning rate and accuracy."""
        opt = self.optimizers()
        scheduler = self.lr_schedulers()
        loss, acc = self.model.forward(
            batch["phoneme_ids"],
            batch["phoneme_ids_len"],
            batch["semantic_ids"],
            batch["semantic_ids_len"],
            batch["bert_feature"],
        )
        self.manual_backward(loss)
        # Gradient accumulation: step on batch 4, 8, 12, ...
        # NOTE(review): the first window therefore spans batches 0-4 (five
        # batches) - confirm whether that is intended.
        should_step = batch_idx > 0 and batch_idx % 4 == 0
        if should_step:
            opt.step()
            opt.zero_grad()
            scheduler.step()

        shared = dict(on_epoch=True, prog_bar=True, sync_dist=True)
        self.log("total_loss", loss, on_step=True, **shared)
        self.log("lr", scheduler.get_last_lr()[0], **shared)
        self.log(f"top_{self.top_k}_acc", acc, on_step=True, **shared)

    def validation_step(self, batch: Dict, batch_idx: int):
        """Validation is disabled; this step does nothing."""
        return

    def configure_optimizers(self):
        """Return ``ScaledAdam`` with a warmup-cosine learning-rate schedule."""
        # ScaledAdam receives one list of names for the single parameter group.
        parameters_names = [[name for name, _ in self.model.named_parameters()]]
        lm_opt = ScaledAdam(
            self.model.parameters(),
            lr=0.01,
            betas=(0.9, 0.95),
            clipping_scale=2.0,
            parameters_names=parameters_names,
            show_dominant_parameters=False,
            clipping_update_period=1000,
        )
        opt_cfg = self.config["optimizer"]
        schedule = WarmupCosineLRSchedule(
            lm_opt,
            init_lr=opt_cfg["lr_init"],
            peak_lr=opt_cfg["lr"],
            end_lr=opt_cfg["lr_end"],
            warmup_steps=opt_cfg["warmup_steps"],
            total_steps=opt_cfg["decay_steps"],
        )
        return {"optimizer": lm_opt, "lr_scheduler": {"scheduler": schedule}}
AR/models/t2s_model.py ADDED
@@ -0,0 +1,586 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_model.py
2
+ # reference: https://github.com/lifeiteng/vall-e
3
+ import torch
4
+ import random
5
+ import numpy as np
6
+
7
+ from tqdm import tqdm
8
+ from typing import List
9
+ from AR.models.utils import make_pad_mask
10
+ from AR.models.utils import (
11
+ topk_sampling,
12
+ sample,
13
+ logits_to_probs,
14
+ multinomial_sample_one_no_sync,
15
+ dpo_loss,
16
+ make_reject_y,
17
+ get_batch_logps
18
+ )
19
+ from AR.modules.embedding import SinePositionalEmbedding
20
+ from AR.modules.embedding import TokenEmbedding
21
+ from AR.modules.transformer import LayerNorm
22
+ from AR.modules.transformer import TransformerEncoder
23
+ from AR.modules.transformer import TransformerEncoderLayer
24
+ from torch import nn
25
+ from torch.nn import functional as F
26
+ from torchmetrics.classification import MulticlassAccuracy
27
+
28
# Reference hyper-parameters. NOTE: Text2SemanticDecoder in this file reads its
# settings from config["model"] using different key names ("head", "n_layer",
# "dropout"), so this dict is not a drop-in config for it.
default_config = dict(
    embedding_dim=512,
    hidden_dim=512,
    num_head=8,
    num_layers=12,
    num_codebook=8,
    p_dropout=0.0,
    vocab_size=1024 + 1,  # 1024 codes plus one extra id used as EOS
    phoneme_vocab_size=512,
    EOS=1024,
)
39
+
40
+
41
@torch.jit.script
class T2SMLP:
    # Feed-forward sub-layer (linear -> ReLU -> linear) that works directly on
    # weight/bias tensors instead of nn.Linear modules.
    def __init__(self, w1, b1, w2, b2):
        self.w1 = w1
        self.b1 = b1
        self.w2 = w2
        self.b2 = b2

    def forward(self, x):
        hidden = F.relu(F.linear(x, self.w1, self.b1))
        return F.linear(hidden, self.w2, self.b2)


@torch.jit.script
class T2SBlock:
    # One post-norm transformer layer (self-attention + MLP) with an explicit
    # key/value cache, operating on raw weight tensors.
    #
    # process_prompt    : attends over the whole input under a mask and
    #                     returns the un-split K/V projections as the cache.
    # decode_next_token : appends the new K/V to the cache and attends over
    #                     the full cache without a mask.
    def __init__(
        self,
        num_heads,
        hidden_dim: int,
        mlp: T2SMLP,
        qkv_w,
        qkv_b,
        out_w,
        out_b,
        norm_w1,
        norm_b1,
        norm_eps1,
        norm_w2,
        norm_b2,
        norm_eps2,
    ):
        self.num_heads = num_heads
        self.mlp = mlp
        self.hidden_dim: int = hidden_dim
        self.qkv_w = qkv_w
        self.qkv_b = qkv_b
        self.out_w = out_w
        self.out_b = out_b
        self.norm_w1 = norm_w1
        self.norm_b1 = norm_b1
        self.norm_eps1 = norm_eps1
        self.norm_w2 = norm_w2
        self.norm_b2 = norm_b2
        self.norm_eps2 = norm_eps2

    def process_prompt(self, x, attn_mask: torch.Tensor):
        # x is (batch, seq, hidden_dim). attn_mask is boolean with True meaning
        # "blocked"; it is inverted below because scaled_dot_product_attention
        # treats True in a boolean mask as "may attend".
        q, k, v = F.linear(x, self.qkv_w, self.qkv_b).chunk(3, dim=-1)

        batch_size = q.shape[0]
        q_len = q.shape[1]
        kv_len = k.shape[1]

        # The cache keeps K/V in (batch, seq, hidden_dim) layout.
        k_cache = k
        v_cache = v

        q = q.view(batch_size, q_len, self.num_heads, -1).transpose(1, 2)
        k = k_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
        v = v_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)

        attn = F.scaled_dot_product_attention(q, k, v, ~attn_mask)

        # (batch, heads, q_len, head_dim) -> (batch, q_len, hidden_dim).
        # The previous permute(2, 0, 1, 3) put the sequence axis first and then
        # reshaped with batch first, which scrambled samples for batch > 1.
        attn = attn.transpose(1, 2).reshape(batch_size, -1, self.hidden_dim)
        attn = F.linear(attn, self.out_w, self.out_b)

        x = F.layer_norm(
            x + attn, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
        )
        x = F.layer_norm(
            x + self.mlp.forward(x),
            [self.hidden_dim],
            self.norm_w2,
            self.norm_b2,
            self.norm_eps2,
        )
        return x, k_cache, v_cache

    def decode_next_token(self, x, k_cache, v_cache):
        # x holds only the newly generated position(s); the caches hold the
        # K/V of everything seen so far and are returned extended.
        q, k, v = F.linear(x, self.qkv_w, self.qkv_b).chunk(3, dim=-1)

        k_cache = torch.cat([k_cache, k], dim=1)
        v_cache = torch.cat([v_cache, v], dim=1)
        kv_len = k_cache.shape[1]

        batch_size = q.shape[0]
        q_len = q.shape[1]

        q = q.view(batch_size, q_len, self.num_heads, -1).transpose(1, 2)
        k = k_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
        v = v_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)

        attn = F.scaled_dot_product_attention(q, k, v)

        # Same head-merge as process_prompt (kept consistent; also correct
        # when more than one query position is decoded at once).
        attn = attn.transpose(1, 2).reshape(batch_size, -1, self.hidden_dim)
        attn = F.linear(attn, self.out_w, self.out_b)

        x = F.layer_norm(
            x + attn, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
        )
        x = F.layer_norm(
            x + self.mlp.forward(x),
            [self.hidden_dim],
            self.norm_w2,
            self.norm_b2,
            self.norm_eps2,
        )
        return x, k_cache, v_cache
148
+
149
+
150
@torch.jit.script
class T2STransformer:
    # Runs a stack of T2SBlock layers in order, keeping one key cache and one
    # value cache entry per layer.
    def __init__(self, num_blocks: int, blocks: List[T2SBlock]):
        self.num_blocks: int = num_blocks
        self.blocks = blocks

    def process_prompt(self, x, attn_mask: torch.Tensor):
        # Feed the full masked input through every layer and collect the
        # per-layer caches each layer returns.
        keys: List[torch.Tensor] = []
        values: List[torch.Tensor] = []
        for layer_idx in range(self.num_blocks):
            x, layer_k, layer_v = self.blocks[layer_idx].process_prompt(x, attn_mask)
            keys.append(layer_k)
            values.append(layer_v)
        return x, keys, values

    def decode_next_token(
        self, x, k_cache: List[torch.Tensor], v_cache: List[torch.Tensor]
    ):
        # The cache lists are updated in place (entry i is replaced by the
        # extended cache of layer i) and also returned.
        for layer_idx in range(self.num_blocks):
            x, new_k, new_v = self.blocks[layer_idx].decode_next_token(
                x, k_cache[layer_idx], v_cache[layer_idx]
            )
            k_cache[layer_idx] = new_k
            v_cache[layer_idx] = new_v
        return x, k_cache, v_cache
172
+
173
+
174
+ class Text2SemanticDecoder(nn.Module):
175
def __init__(self, config, norm_first=False, top_k=3):
    """Build embeddings, the transformer stack, the output head and metric.

    Hyper-parameters are read from ``config["model"]``. After the regular
    ``TransformerEncoder`` (``self.h``) is built, its per-layer weights are
    wrapped in ``T2SBlock`` objects and stored as ``self.t2s_transformer``
    (used by ``infer_panel``).
    """
    super(Text2SemanticDecoder, self).__init__()
    model_cfg = config["model"]
    self.model_dim = model_cfg["hidden_dim"]
    self.embedding_dim = model_cfg["embedding_dim"]
    self.num_head = model_cfg["head"]
    self.num_layers = model_cfg["n_layer"]
    self.norm_first = norm_first
    self.vocab_size = model_cfg["vocab_size"]
    self.phoneme_vocab_size = model_cfg["phoneme_vocab_size"]
    self.p_dropout = model_cfg["dropout"]
    self.EOS = model_cfg["EOS"]
    # EOS must be the last id of the audio vocabulary.
    assert self.EOS == self.vocab_size - 1

    # Text side: BERT features (1024-d) are projected to the embedding size.
    self.bert_proj = nn.Linear(1024, self.embedding_dim)
    self.ar_text_embedding = TokenEmbedding(
        self.embedding_dim, self.phoneme_vocab_size, self.p_dropout
    )
    self.ar_text_position = SinePositionalEmbedding(
        self.embedding_dim, dropout=0.1, scale=False, alpha=True
    )

    # Audio (semantic token) side.
    self.ar_audio_embedding = TokenEmbedding(
        self.embedding_dim, self.vocab_size, self.p_dropout
    )
    self.ar_audio_position = SinePositionalEmbedding(
        self.embedding_dim, dropout=0.1, scale=False, alpha=True
    )

    encoder_layer = TransformerEncoderLayer(
        d_model=self.model_dim,
        nhead=self.num_head,
        dim_feedforward=self.model_dim * 4,
        dropout=0.1,
        batch_first=True,
        norm_first=norm_first,
    )
    final_norm = LayerNorm(self.model_dim) if norm_first else None
    self.h = TransformerEncoder(
        encoder_layer, num_layers=self.num_layers, norm=final_norm
    )

    self.ar_predict_layer = nn.Linear(self.model_dim, self.vocab_size, bias=False)
    self.loss_fct = nn.CrossEntropyLoss(reduction="sum")

    self.ar_accuracy_metric = MulticlassAccuracy(
        self.vocab_size,
        top_k=top_k,
        average="micro",
        multidim_average="global",
        ignore_index=self.EOS,
    )

    # Wrap each encoder layer's tensors in a T2SBlock; the blocks reference
    # the same weight tensors as self.h rather than copies.
    blocks = []
    for layer in (self.h.layers[i] for i in range(self.num_layers)):
        feed_forward = T2SMLP(
            layer.linear1.weight,
            layer.linear1.bias,
            layer.linear2.weight,
            layer.linear2.bias,
        )
        attention = layer.self_attn
        blocks.append(
            T2SBlock(
                self.num_head,
                self.model_dim,
                feed_forward,
                attention.in_proj_weight,
                attention.in_proj_bias,
                attention.out_proj.weight,
                attention.out_proj.bias,
                layer.norm1.weight,
                layer.norm1.bias,
                layer.norm1.eps,
                layer.norm2.weight,
                layer.norm2.bias,
                layer.norm2.eps,
            )
        )

    self.t2s_transformer = T2STransformer(self.num_layers, blocks)
258
+
259
def make_input_data(self, x, x_lens, y, y_lens, bert_feature):
    """Build the decoder input, additive attention mask and targets.

    Args:
        x: phoneme ids.
        x_lens: per-sample lengths of ``x``.
        y: semantic ids.
        y_lens: per-sample lengths of ``y``.
        bert_feature: BERT features; transposed on dims 1 and 2 before the
            projection.

    Returns:
        ``(xy_pos, xy_attn_mask, targets)``: the concatenated text/audio
        embeddings, a float mask holding ``-inf`` at blocked positions and 0
        elsewhere, and the EOS-padded, shifted target ids.
    """
    text = self.ar_text_embedding(x)
    text = text + self.bert_proj(bert_feature.transpose(1, 2))
    text = self.ar_text_position(text)
    x_mask = make_pad_mask(x_lens)

    y_mask = make_pad_mask(y_lens)
    y_mask_int = y_mask.type(torch.int64)
    # Zero out padded code positions before EOS is written into them.
    codes = y.type(torch.int64) * (1 - y_mask_int)

    # Training inputs/targets for the AR decoder.
    y, targets = self.pad_y_eos(codes, y_mask_int, eos_id=self.EOS)
    x_len = x_lens.max()
    y_len = y_lens.max()
    y_pos = self.ar_audio_position(self.ar_audio_embedding(y))

    padding_mask = torch.concat([x_mask, y_mask], dim=1)

    # Text rows: may see all text, never the audio part.
    text_rows = F.pad(
        torch.zeros((x_len, x_len), dtype=torch.bool, device=text.device),
        (0, y_len),
        value=True,
    )
    # Audio rows: may see all text and are causal over the audio part.
    causal = torch.triu(
        torch.ones(y_len, y_len, dtype=torch.bool, device=text.device),
        diagonal=1,
    )
    audio_rows = F.pad(causal, (x_len, 0), value=False)
    blocked = torch.concat([text_rows, audio_rows], dim=0)

    bsz = text.shape[0]
    src_len = x_len + y_len
    per_head_padding = (
        padding_mask.view(bsz, 1, 1, src_len)
        .expand(-1, self.num_head, -1, -1)
        .reshape(bsz * self.num_head, 1, src_len)
    )
    blocked = blocked.logical_or(per_head_padding)

    # Convert the boolean mask to an additive one.
    xy_attn_mask = torch.zeros_like(blocked, dtype=text.dtype)
    xy_attn_mask.masked_fill_(blocked, float("-inf"))

    # Text and the complete audio sequence are fed to the model in one pass.
    xy_pos = torch.concat([text, y_pos], dim=1)

    return xy_pos, xy_attn_mask, targets
311
+
312
def forward(self, x, x_lens, y, y_lens, bert_feature):
    """Compute the training loss (cross-entropy plus DPO term) and accuracy.

    Args:
        x: phoneme_ids
        y: semantic_ids

    Returns:
        ``(loss, acc)`` where ``loss`` is the summed cross-entropy on the
        real targets plus the value returned by ``dpo_loss`` for the
        real-vs-rejected pair, and ``acc`` is a Python float.
    """
    # The rejected sequence is produced first, before any model call.
    reject_y, reject_y_lens = make_reject_y(y, y_lens)

    # Pass over the real (accepted) targets.
    xy_pos, xy_attn_mask, targets = self.make_input_data(
        x, x_lens, y, y_lens, bert_feature
    )
    xy_dec, _ = self.h((xy_pos, None), mask=xy_attn_mask)
    x_len = x_lens.max()
    logits = self.ar_predict_layer(xy_dec[:, x_len:])

    # DPO: same text, rejected audio sequence.
    reject_xy_pos, reject_xy_attn_mask, reject_targets = self.make_input_data(
        x, x_lens, reject_y, reject_y_lens, bert_feature
    )
    reject_xy_dec, _ = self.h((reject_xy_pos, None), mask=reject_xy_attn_mask)
    reject_logits = self.ar_predict_layer(reject_xy_dec[:, x_len:])

    # From feiteng: longer durations should contribute more gradient, hence
    # reduction="sum" rather than a mean.
    class_first_logits = logits.permute(0, 2, 1)
    loss_1 = F.cross_entropy(class_first_logits, targets, reduction="sum")
    acc = self.ar_accuracy_metric(class_first_logits.detach(), targets).item()

    A_logits, R_logits = get_batch_logps(
        logits, reject_logits, targets, reject_targets
    )
    loss_2, _, _ = dpo_loss(A_logits, R_logits, 0, 0, 0.2, reference_free=True)

    return loss_1 + loss_2, acc
351
+
352
def forward_old(self, x, x_lens, y, y_lens, bert_feature):
    """Cross-entropy-only training pass (no DPO term).

    The input/mask/target construction used to be a line-for-line copy of
    ``make_input_data``; it now calls that helper so the two cannot drift
    apart. The sequence of operations is unchanged.

    Args:
        x: phoneme_ids
        x_lens: per-sample lengths of ``x``.
        y: semantic_ids
        y_lens: per-sample lengths of ``y``.
        bert_feature: BERT features passed through to ``make_input_data``.

    Returns:
        ``(loss, acc)``: summed cross-entropy and accuracy as a Python float.
    """
    xy_pos, xy_attn_mask, targets = self.make_input_data(
        x, x_lens, y, y_lens, bert_feature
    )
    xy_dec, _ = self.h(
        (xy_pos, None),
        mask=xy_attn_mask,
    )
    # Keep only the audio positions; move the class axis to dim 1 for the loss.
    x_len = x_lens.max()
    logits = self.ar_predict_layer(xy_dec[:, x_len:]).permute(0, 2, 1)
    # From feiteng: longer durations should contribute more gradient, hence
    # reduction="sum".
    loss = F.cross_entropy(logits, targets, reduction="sum")
    acc = self.ar_accuracy_metric(logits.detach(), targets).item()
    return loss, acc
413
+
414
# TODO: review how this differs from forward, and what to pass as `prompts`
# when no semantic tokens are available.
def infer(
    self,
    x,
    x_lens,
    prompts,
    bert_feature,
    top_k: int = -100,
    early_stop_num: int = -1,
    temperature: float = 1.0,
):
    """Autoregressively extend ``prompts`` without a key/value cache.

    Each step re-runs the whole transformer on the text plus everything
    generated so far, for at most 1500 steps. Generation stops when EOS is
    either the arg-max or the sampled token, or when more than
    ``early_stop_num`` tokens were generated (if it is not -1).
    ``x_lens`` is accepted but not used.

    Returns:
        The prompt followed by the generated ids. If generation stops before
        anything was produced, a single 0 id is appended.
    """
    x = self.ar_text_embedding(x)
    x = x + self.bert_proj(bert_feature.transpose(1, 2))
    x = self.ar_text_position(x)

    # AR decoder
    y = prompts
    prefix_len = y.shape[1]
    x_len = x.shape[1]
    x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
    stop = False
    for _ in tqdm(range(1500)):
        y_pos = self.ar_audio_position(self.ar_audio_embedding(y))
        # Text and the growing audio sequence are fed to the model together.
        xy_pos = torch.concat([x, y_pos], dim=1)
        y_len = y.shape[1]

        text_rows = F.pad(x_attn_mask, (0, y_len), value=True)
        causal = torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1)
        audio_rows = F.pad(causal, (x_len, 0), value=False)
        xy_attn_mask = torch.concat([text_rows, audio_rows], dim=0).to(y.device)

        xy_dec, _ = self.h((xy_pos, None), mask=xy_attn_mask)
        logits = self.ar_predict_layer(xy_dec[:, -1])
        samples = topk_sampling(
            logits, top_k=top_k, top_p=1.0, temperature=temperature
        )

        generated = y.shape[1] - prefix_len
        if early_stop_num != -1 and generated > early_stop_num:
            print("use early stop num:", early_stop_num)
            stop = True

        if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
            stop = True

        if not stop:
            # The newly sampled ids are appended to form the next input
            # (samples has shape [1, 1]; the first 1 is the batch size).
            y = torch.concat([y, samples], dim=1)
            continue

        if prompts.shape[1] == y.shape[1]:
            y = torch.concat([y, torch.zeros_like(samples)], dim=1)
            print("bad zero prediction")
        print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
        break
    return y
483
+
484
def pad_y_eos(self, y, y_mask_int, eos_id):
    """Return teacher-forcing inputs and one-step-shifted targets.

    One extra column is appended to ``y`` (value 0) and to ``y_mask_int``
    (value 1); ``eos_id`` is then added wherever the mask is 1, so padded
    positions and the new last column become ``eos_id`` when ``y`` is 0
    there. The result is split into ``[:, :-1]`` (inputs) and ``[:, 1:]``
    (targets).
    """
    padded_codes = F.pad(y, (0, 1), value=0)
    padded_mask = F.pad(y_mask_int, (0, 1), value=1)
    with_eos = padded_codes + eos_id * padded_mask
    # Offset by one position: inputs drop the last column, targets the first.
    decoder_inputs = with_eos[:, :-1]
    decoder_targets = with_eos[:, 1:]
    return decoder_inputs, decoder_targets
490
+
491
def infer_panel(
    self,
    x,  # all text tokens
    x_lens,
    prompts,  # reference-audio tokens, or None
    bert_feature,
    top_k: int = -100,
    top_p: int = 100,
    early_stop_num: int = -1,
    temperature: float = 1.0,
):
    """Generate semantic ids with the key/value-cached ``t2s_transformer``.

    Step 0 runs the full masked prompt (text plus optional audio prompt)
    through ``process_prompt``; every later step feeds only the newest token
    to ``decode_next_token``. At most 1500 steps are run. ``x_lens`` is
    accepted but not used.

    Returns:
        ``(y[:, :-1], n)``: all ids except the last sampled one, and
        ``idx - 1`` for the last loop index, or 0 when ``prompts`` was None.
    """
    x = self.ar_text_embedding(x)
    x = x + self.bert_proj(bert_feature.transpose(1, 2))
    x = self.ar_text_position(x)

    # AR decoder
    y = prompts
    x_len = x.shape[1]
    x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
    stop = False

    k_cache = None
    v_cache = None

    # ---- first step inputs -------------------------------------------------
    ref_free = y is None
    if ref_free:
        # No reference audio: start from an empty id sequence.
        y_len = 0
        prefix_len = 0
        xy_pos = x
        y = torch.zeros(x.shape[0], 0, dtype=torch.int, device=x.device)
    else:
        y_emb = self.ar_audio_embedding(y)
        y_len = y_emb.shape[1]
        prefix_len = y.shape[1]
        y_pos = self.ar_audio_position(y_emb)
        xy_pos = torch.concat([x, y_pos], dim=1)

    # Text rows: zeros over text, extended with ones over audio -> (x, x+y).
    text_rows = F.pad(x_attn_mask, (0, y_len), value=True)
    # Audio rows: upper-triangular ones over audio, zeros over text -> (y, x+y).
    causal = torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1)
    audio_rows = F.pad(causal, (x_len, 0), value=False)
    xy_attn_mask = torch.concat([text_rows, audio_rows], dim=0).to(x.device)

    for idx in tqdm(range(1500)):
        first_step = idx == 0
        if first_step:
            xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(
                xy_pos, xy_attn_mask
            )
        else:
            xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(
                xy_pos, k_cache, v_cache
            )

        logits = self.ar_predict_layer(xy_dec[:, -1])

        if first_step:
            # Drop the last logit (the EOS id) so the first step cannot
            # sample EOS.
            logits = logits[:, :-1]

        samples = sample(
            logits[0],
            y,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=1.35,
            temperature=temperature,
        )[0].unsqueeze(0)

        y = torch.concat([y, samples], dim=1)

        if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
            print("use early stop num:", early_stop_num)
            stop = True

        if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
            stop = True

        if stop:
            if y.shape[1] == 0:
                y = torch.concat([y, torch.zeros_like(samples)], dim=1)
                print("bad zero prediction")
            print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
            break

        # ---- prepare the next step: embed only the newest token ------------
        position = self.ar_audio_position
        y_emb = self.ar_audio_embedding(y[:, -1:])
        xy_pos = y_emb * position.x_scale + position.alpha * position.pe[
            :, y_len + idx
        ].to(dtype=y_emb.dtype, device=y_emb.device)

    if ref_free:
        return y[:, :-1], 0
    return y[:, :-1], idx - 1
AR/models/t2s_model_onnx.py ADDED
@@ -0,0 +1,338 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_model.py
2
+ # reference: https://github.com/lifeiteng/vall-e
3
+ import torch
4
+ from tqdm import tqdm
5
+
6
+ from AR.modules.embedding_onnx import SinePositionalEmbedding
7
+ from AR.modules.embedding_onnx import TokenEmbedding
8
+ from AR.modules.transformer_onnx import LayerNorm
9
+ from AR.modules.transformer_onnx import TransformerEncoder
10
+ from AR.modules.transformer_onnx import TransformerEncoderLayer
11
+ from torch import nn
12
+ from torch.nn import functional as F
13
+ from torchmetrics.classification import MulticlassAccuracy
14
+
15
# Reference hyper-parameters. NOTE: Text2SemanticDecoder in this file reads its
# settings from config["model"] using different key names ("head", "n_layer",
# "dropout").
default_config = dict(
    embedding_dim=512,
    hidden_dim=512,
    num_head=8,
    num_layers=12,
    num_codebook=8,
    p_dropout=0.0,
    vocab_size=1024 + 1,
    phoneme_vocab_size=512,
    EOS=1024,
)

# One-element float32 tensor holding -inf; used as the fill value by the
# top-k filter in logits_to_probs.
inf_tensor_value = torch.tensor([float("-inf")], dtype=torch.float32)
28
+
29
def logits_to_probs(
    logits,
    previous_tokens=None,
    temperature: float = 1.0,
    top_k=None,
    top_p=None,
    repetition_penalty: float = 1.0,
):
    """Turn 1-D logits into sampling probabilities.

    Applies, in order: repetition penalty on ``previous_tokens``, nucleus
    (top-p) filtering, temperature scaling, top-k filtering and softmax.

    Args:
        logits: 1-D tensor of scores. NOTE: modified in place when the
            repetition penalty is applied.
        previous_tokens: ids already generated, or None to skip the
            repetition penalty. (Previously None crashed on ``.squeeze()``
            even though it is the default.)
        temperature: divisor for the logits, floored at 1e-5.
        top_k: keep only the ``top_k`` largest logits when not None.
        top_p: keep the smallest prefix of sorted tokens whose cumulative
            probability exceeds ``top_p`` when not None and < 1.0.
        repetition_penalty: 1.0 disables the penalty.

    Returns:
        Tensor of probabilities with the same shape as ``logits``.
    """
    if previous_tokens is not None:
        previous_tokens = previous_tokens.squeeze()
        if repetition_penalty != 1.0:
            previous_tokens = previous_tokens.long()
            score = torch.gather(logits, dim=0, index=previous_tokens)
            # Push already-used tokens down: negative scores grow in
            # magnitude, positive scores shrink.
            score = torch.where(
                score < 0, score * repetition_penalty, score / repetition_penalty
            )
            logits.scatter_(dim=0, index=previous_tokens, src=score)

    if top_p is not None and top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cum_probs = torch.cumsum(
            torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
        )
        sorted_indices_to_remove = cum_probs > top_p
        sorted_indices_to_remove[0] = False  # keep at least one option
        # Map the removal flags from sorted order back to vocabulary order.
        indices_to_remove = sorted_indices_to_remove.scatter(
            dim=0, index=sorted_indices, src=sorted_indices_to_remove
        )
        logits = logits.masked_fill(indices_to_remove, -float("Inf"))

    logits = logits / max(temperature, 1e-5)

    if top_k is not None:
        v, _ = torch.topk(logits, top_k)
        pivot = v.select(-1, -1).unsqueeze(-1)
        # Everything below the k-th largest logit is set to -inf.
        logits = torch.where(logits < pivot, inf_tensor_value, logits)

    probs = torch.nn.functional.softmax(logits, dim=-1)
    return probs
67
+
68
+
69
def multinomial_sample_one_no_sync(
    probs_sort
):  # Does multinomial sampling without a cuda synchronization
    """Draw one index from the categorical distribution ``probs_sort``.

    Uses the exponential-race trick: with ``q ~ Exp(1)`` drawn independently
    per entry, ``argmax(probs / q)`` is distributed according to ``probs``.

    The noise used to be ``torch.randn_like`` (Gaussian). Gaussian noise can
    be negative, which makes zero-probability entries win the arg-max and
    does not reproduce the distribution. The Exp(1) noise is built as
    ``-log(U)`` from ``rand_like`` instead of ``Tensor.exponential_``.
    NOTE(review): presumably the original avoided ``exponential_`` for ONNX
    export -- confirm that ``rand_like`` + ``log`` export as needed.

    Returns:
        int32 tensor with the sampled index, last dimension kept with size 1.
    """
    # Clamp away from 0 so that -log(U) stays finite; U < 1 keeps q > 0.
    tiny = torch.finfo(probs_sort.dtype).tiny
    uniform = torch.clamp(torch.rand_like(probs_sort), min=tiny)
    q = -torch.log(uniform)
    return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
74
+
75
+
76
def sample(
    logits,
    previous_tokens,
    **sampling_kwargs,
):
    """Sample the next token id from ``logits``.

    ``sampling_kwargs`` are forwarded unchanged to ``logits_to_probs``.

    Returns:
        ``(idx_next, probs)``: the sampled index and the probabilities it
        was drawn from.
    """
    probs = logits_to_probs(
        logits=logits, previous_tokens=previous_tokens, **sampling_kwargs
    )
    return multinomial_sample_one_no_sync(probs), probs
86
+
87
+
88
class OnnxEncoder(nn.Module):
    """Text encoder: token embedding + projected BERT features + position."""

    def __init__(self, ar_text_embedding, bert_proj, ar_text_position):
        super().__init__()
        self.ar_text_embedding = ar_text_embedding
        self.bert_proj = bert_proj
        self.ar_text_position = ar_text_position

    def forward(self, x, bert_feature):
        """Embed ``x``, add the projected (transposed) BERT features, then
        apply the positional module."""
        token_emb = self.ar_text_embedding(x)
        bert_emb = self.bert_proj(bert_feature.transpose(1, 2))
        return self.ar_text_position(token_emb + bert_emb)
99
+
100
+
101
class T2SFirstStageDecoder(nn.Module):
    """First decoding step: runs text plus prompt and samples one token.

    The attention masks are derived from matmuls of zero-valued tensors
    rather than from Python-side shapes.
    NOTE(review): presumably so shapes stay dynamic under tracing -- confirm.
    """

    def __init__(self, ar_audio_embedding, ar_audio_position, h, ar_predict_layer, loss_fct, ar_accuracy_metric,
                 top_k, early_stop_num, num_layers):
        super().__init__()
        self.ar_audio_embedding = ar_audio_embedding
        self.ar_audio_position = ar_audio_position
        self.h = h
        self.ar_predict_layer = ar_predict_layer
        self.loss_fct = loss_fct
        self.ar_accuracy_metric = ar_accuracy_metric
        self.top_k = top_k
        self.early_stop_num = early_stop_num
        self.num_layers = num_layers

    def forward(self, x, prompt):
        """Return ``(y, k, v, y_emb, x_example)`` after the first step.

        ``y`` is ``prompt`` with one sampled id appended; ``k``/``v``/``y_emb``
        are read back from the cache dict after ``self.h`` ran; ``x_example``
        is ``x[:, :, 0] * 0.0``.
        """
        y = prompt
        # All-zero tensors that only carry the sequence lengths.
        x_example = x[:, :, 0] * 0.0

        y_emb = self.ar_audio_embedding(y)
        y_pos = self.ar_audio_position(y_emb)
        xy_pos = torch.concat([x, y_pos], dim=1)
        y_example = y_pos[:, :, 0] * 0.0

        x_col = x_example.transpose(0, 1)
        y_col = y_example.transpose(0, 1)

        # Text-to-text block: all False.
        x_attn_mask = torch.matmul(x_col, x_example).bool()

        # Audio-to-audio block: True strictly above the diagonal, built from
        # column index minus row index.
        col_index = torch.cumsum(
            torch.ones_like(torch.matmul(y_col, y_example), dtype=torch.int64), dim=1
        )
        row_index = torch.cumsum(torch.ones_like(y_col, dtype=torch.int64), dim=0)
        y_attn_mask = (col_index - row_index) > 0

        x_y_pad = torch.matmul(x_col, y_example).bool()
        y_x_pad = torch.matmul(y_col, x_example).bool()
        x_attn_mask_pad = torch.cat([x_attn_mask, torch.ones_like(x_y_pad)], dim=1)
        y_attn_mask = torch.cat([y_x_pad, y_attn_mask], dim=1)
        xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)

        # Zero-filled per-layer caches. NOTE: the feature width is hard-coded
        # to 512 here.
        cache_seed = x_attn_mask_pad[0].float().unsqueeze(-1)
        k_init = torch.matmul(cache_seed, torch.zeros((1, 512)))\
            .unsqueeze(1).repeat(self.num_layers, 1, 1, 1)
        v_init = torch.matmul(cache_seed, torch.zeros((1, 512)))\
            .unsqueeze(1).repeat(self.num_layers, 1, 1, 1)

        cache = {
            "all_stage": self.num_layers,
            "k": k_init,
            "v": v_init,
            "y_emb": y_emb,
            "first_infer": 1,
            "stage": 0,
        }

        xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache)
        logits = self.ar_predict_layer(xy_dec[:, -1])
        samples = sample(logits[0], y, top_k=self.top_k, top_p=1.0, repetition_penalty=1.35)[0].unsqueeze(0)

        y = torch.concat([y, samples], dim=1)

        return y, cache["k"], cache["v"], cache["y_emb"], x_example
160
+
161
+
162
class T2SStageDecoder(nn.Module):
    """One incremental decoding step on top of an existing K/V cache."""

    def __init__(self, ar_audio_embedding, ar_audio_position, h, ar_predict_layer, loss_fct, ar_accuracy_metric,
                 top_k, early_stop_num, num_layers):
        super().__init__()
        self.ar_audio_embedding = ar_audio_embedding
        self.ar_audio_position = ar_audio_position
        self.h = h
        self.ar_predict_layer = ar_predict_layer
        self.loss_fct = loss_fct
        self.ar_accuracy_metric = ar_accuracy_metric
        self.top_k = top_k
        self.early_stop_num = early_stop_num
        self.num_layers = num_layers

    def forward(self, y, k, v, y_emb, x_example):
        """Return ``(y, k, v, y_emb, logits, samples)`` after one more step.

        ``k`` and ``v`` are padded by one slot along dim 0 of the trailing
        three dims before being handed to ``self.h`` through the cache dict;
        the values returned are read back from that dict afterwards.
        """
        grow_by_one = (0, 0, 0, 0, 0, 1)
        padded_k = torch.nn.functional.pad(k, grow_by_one)
        padded_v = torch.nn.functional.pad(v, grow_by_one)

        # Embed only the newest token and append it to the running embeddings.
        extended_y_emb = torch.cat(
            [y_emb, self.ar_audio_embedding(y[:, -1:])], 1
        )
        cache = {
            "all_stage": self.num_layers,
            "k": padded_k,
            "v": padded_v,
            "y_emb": extended_y_emb,
            "first_infer": 0,
            "stage": 0,
        }

        y_pos = self.ar_audio_position(extended_y_emb)
        # Only the last position is fed to the transformer.
        xy_pos = y_pos[:, -1:]

        y_example = y_pos[:, :, 0] * 0.0

        # All-False mask covering the text and audio positions.
        xy_attn_mask = torch.zeros_like(
            torch.cat([x_example, y_example], dim=1), dtype=torch.bool
        )

        xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache)
        logits = self.ar_predict_layer(xy_dec[:, -1])
        samples = sample(logits[0], y, top_k=self.top_k, top_p=1.0, repetition_penalty=1.35)[0].unsqueeze(0)

        y = torch.concat([y, samples], dim=1)

        return y, cache["k"], cache["v"], cache["y_emb"], logits, samples
206
+
207
+
208
+ class Text2SemanticDecoder(nn.Module):
209
def __init__(self, config, norm_first=False, top_k=3):
    """Build the embeddings, transformer stack, output head and metric.

    Hyper-parameters come from ``config["model"]``. ``self.top_k`` and
    ``self.early_stop_num`` are stored as one-element ``LongTensor`` values
    (1 and -1); the ``top_k`` argument is only used for the accuracy metric.
    """
    super(Text2SemanticDecoder, self).__init__()
    model_cfg = config["model"]
    self.model_dim = model_cfg["hidden_dim"]
    self.embedding_dim = model_cfg["embedding_dim"]
    self.num_head = model_cfg["head"]
    self.num_layers = model_cfg["n_layer"]
    self.norm_first = norm_first
    self.vocab_size = model_cfg["vocab_size"]
    self.phoneme_vocab_size = model_cfg["phoneme_vocab_size"]
    self.p_dropout = float(model_cfg["dropout"])
    self.EOS = model_cfg["EOS"]
    # EOS must be the last id of the audio vocabulary.
    assert self.EOS == self.vocab_size - 1

    self.bert_proj = nn.Linear(1024, self.embedding_dim)
    self.ar_text_embedding = TokenEmbedding(
        self.embedding_dim, self.phoneme_vocab_size, self.p_dropout
    )
    self.ar_text_position = SinePositionalEmbedding(
        self.embedding_dim, dropout=0.1, scale=False, alpha=True
    )
    self.ar_audio_embedding = TokenEmbedding(
        self.embedding_dim, self.vocab_size, self.p_dropout
    )
    self.ar_audio_position = SinePositionalEmbedding(
        self.embedding_dim, dropout=0.1, scale=False, alpha=True
    )

    encoder_layer = TransformerEncoderLayer(
        d_model=self.model_dim,
        nhead=self.num_head,
        dim_feedforward=self.model_dim * 4,
        dropout=0.1,
        batch_first=True,
        norm_first=norm_first,
    )
    final_norm = LayerNorm(self.model_dim) if norm_first else None
    self.h = TransformerEncoder(
        encoder_layer, num_layers=self.num_layers, norm=final_norm
    )

    self.ar_predict_layer = nn.Linear(self.model_dim, self.vocab_size, bias=False)
    self.loss_fct = nn.CrossEntropyLoss(reduction="sum")
    self.ar_accuracy_metric = MulticlassAccuracy(
        self.vocab_size,
        top_k=top_k,
        average="micro",
        multidim_average="global",
        ignore_index=self.EOS,
    )
    self.top_k = torch.LongTensor([1])
    self.early_stop_num = torch.LongTensor([-1])
250
+
251
def init_onnx(self):
    """Create the encoder and the two decoder wrappers over this model's
    sub-modules. Both decoders receive the same constructor arguments."""
    self.onnx_encoder = OnnxEncoder(
        self.ar_text_embedding, self.bert_proj, self.ar_text_position
    )
    shared_decoder_args = (
        self.ar_audio_embedding,
        self.ar_audio_position,
        self.h,
        self.ar_predict_layer,
        self.loss_fct,
        self.ar_accuracy_metric,
        self.top_k,
        self.early_stop_num,
        self.num_layers,
    )
    self.first_stage_decoder = T2SFirstStageDecoder(*shared_decoder_args)
    self.stage_decoder = T2SStageDecoder(*shared_decoder_args)
259
+
260
def forward(self, x, prompts, bert_feature):
    """Generate ids with the first-stage and incremental stage decoders.

    Fix: the previous version unpacked a ``stage`` value that neither decoder
    returns and passed it to ``stage_decoder``. ``T2SFirstStageDecoder``
    returns 5 values, and ``T2SStageDecoder.forward`` takes 5 arguments and
    returns 6, so the old code always raised. The calls now match those
    signatures.

    Args:
        x: phoneme ids, encoded by ``self.onnx_encoder``.
        prompts: prompt ids; its second dimension is the prefix length.
        bert_feature: BERT features for the encoder.

    Returns:
        ``(y, idx)``: the id sequence with its last entry (row 0) overwritten
        by 0, and the index of the last stage-decoder step that ran.
    """
    early_stop_num = self.early_stop_num
    prefix_len = prompts.shape[1]

    x = self.onnx_encoder(x, bert_feature)
    y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts)

    stop = False
    for idx in range(1, 1500):
        y, k, v, y_emb, logits, samples = self.stage_decoder(
            y, k, v, y_emb, x_example
        )
        if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
            stop = True
        if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
            stop = True
        if stop:
            break
    # The final sampled id is replaced by 0 in place.
    y[0, -1] = 0
    return y, idx
279
+
280
def infer(self, x, prompts, bert_feature):
    """Generate ids step by step through ``self.h`` with a cache dict.

    The first iteration feeds the text plus the whole prompt under a causal
    mask; later iterations feed only the newest position. Stops on EOS
    (arg-max or sampled), on ``self.early_stop_num`` if it is not -1, or
    after 1500 iterations.

    Returns:
        ``(y, idx)``: prompt plus generated ids (a single 0 is appended if
        nothing was generated) and the last loop index.
    """
    top_k = self.top_k
    early_stop_num = self.early_stop_num

    x = self.onnx_encoder(x, bert_feature)

    y = prompts
    prefix_len = y.shape[1]
    x_len = x.shape[1]
    x_example = x[:, :, 0] * 0.0
    # All-False text-to-text mask shaped like x_example^T @ x_example.
    x_attn_mask = torch.zeros_like(
        torch.matmul(x_example.transpose(0, 1), x_example), dtype=torch.bool
    )

    stop = False
    cache = {
        "all_stage": self.num_layers,
        "k": [None] * self.num_layers,
        "v": [None] * self.num_layers,
        "y_emb": None,
        "first_infer": 1,
        "stage": 0,
    }
    for idx in range(1500):
        first_pass = cache["first_infer"] == 1

        if first_pass:
            y_emb = self.ar_audio_embedding(y)
        else:
            # Reuse earlier embeddings; embed only the newest token.
            y_emb = torch.cat(
                [cache["y_emb"], self.ar_audio_embedding(y[:, -1:])], 1
            )
        cache["y_emb"] = y_emb
        y_pos = self.ar_audio_position(y_emb)
        y_len = y_pos.shape[1]

        if first_pass:
            xy_pos = torch.concat([x, y_pos], dim=1)
            text_rows = F.pad(x_attn_mask, (0, y_len), value=True)
            causal = torch.triu(
                torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1
            )
            audio_rows = F.pad(causal, (x_len, 0), value=False)
            xy_attn_mask = torch.concat([text_rows, audio_rows], dim=0)
        else:
            xy_pos = y_pos[:, -1:]
            xy_attn_mask = torch.zeros((1, x_len + y_len), dtype=torch.bool)

        xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache)
        logits = self.ar_predict_layer(xy_dec[:, -1])
        samples = sample(logits[0], y, top_k=top_k, top_p=1.0, repetition_penalty=1.35)[0].unsqueeze(0)

        if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
            stop = True
        if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
            stop = True

        if stop:
            if prompts.shape[1] == y.shape[1]:
                y = torch.concat([y, torch.zeros_like(samples)], dim=1)
            break

        y = torch.concat([y, samples], dim=1)
        cache["first_infer"] = 0
    return y, idx
AR/models/utils.py ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/utils.py
2
+ # reference: https://github.com/lifeiteng/vall-e
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from typing import Tuple
6
+
7
def sequence_mask(length, max_length=None):
    """Build a boolean validity mask from per-sequence lengths.

    Args:
        length: 1-D tensor of sequence lengths.
        max_length: Number of mask columns. When ``None`` the largest value
            in ``length`` is used.

    Returns:
        Bool tensor of shape ``(len(length), max_length)`` that is ``True``
        for positions strictly below the corresponding length.
    """
    limit = length.max() if max_length is None else max_length
    positions = torch.arange(limit, dtype=length.dtype, device=length.device)
    # Broadcast a (1, limit) row of positions against a (n, 1) column of lengths.
    return positions[None, :] < length[:, None]
12
+
13
+
14
def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
    """Build a padding mask from per-sentence lengths.

    Args:
        lengths: A 1-D tensor containing sentence lengths.
        max_len: Minimum number of mask columns; the mask is as wide as the
            larger of ``max_len`` and ``lengths.max()``.

    Returns:
        A 2-D bool tensor where padded positions are ``True`` and real
        positions are ``False``.

    Example::

        lengths = torch.tensor([1, 3, 2, 5])
        make_pad_mask(lengths)
        tensor([[False,  True,  True,  True,  True],
                [False, False, False,  True,  True],
                [False, False,  True,  True,  True],
                [False, False, False, False, False]])
    """
    assert lengths.ndim == 1, lengths.ndim
    width = max(max_len, lengths.max())
    batch = lengths.size(0)
    positions = torch.arange(0, width, device=lengths.device)
    position_grid = positions[None, :].expand(batch, width)

    # A position is padding once it reaches the sentence length.
    return position_grid >= lengths[:, None]
40
+
41
+
42
+ # https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py
43
def top_k_top_p_filtering(
    logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1
):
    """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering.

    The input tensor is never modified: whenever a filter is applied a new
    tensor is returned. If neither filter is active (``top_k <= 0`` and
    ``top_p >= 1.0``) the input tensor itself is returned.

    Args:
        logits: Logits with the vocabulary on the last dimension, e.g.
            ``(batch size, vocabulary size)``.
        top_k: If > 0, keep only the ``top_k`` highest-scoring tokens.
        top_p: If < 1.0, keep the smallest set of top tokens whose cumulative
            probability reaches ``top_p`` (nucleus filtering, Holtzman et al.,
            http://arxiv.org/abs/1904.09751).
        filter_value: Value written into the removed positions.
        min_tokens_to_keep: Minimum number of tokens kept per row.

    Returns:
        Logits of the same shape with removed positions set to ``filter_value``.

    From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    if top_k > 0:
        top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1))  # Safety check
        # Remove every token scoring below the k-th best token of its row.
        kth_best = torch.topk(logits, top_k)[0][..., -1, None]
        logits = logits.masked_fill(logits < kth_best, filter_value)

    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold.
        sorted_indices_to_remove = cumulative_probs > top_p
        if min_tokens_to_keep > 1:
            # Keep at least min_tokens_to_keep (the shift below re-adds one).
            sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
        # Shift right so the first token that crosses the threshold is kept too.
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        # Scatter the sorted mask back to the original vocabulary order. The
        # last dimension is used so 1-D logits work as well as batched ones.
        indices_to_remove = sorted_indices_to_remove.scatter(
            -1, sorted_indices, sorted_indices_to_remove
        )
        logits = logits.masked_fill(indices_to_remove, filter_value)
    return logits
80
+
81
+
82
def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0):
    """Sample one token per row after temperature scaling and top-k/top-p filtering.

    Args:
        logits: Logits with the vocabulary on the last dimension.
        top_k: Number of highest-probability vocabulary tokens kept by the
            top-k filter. Defaults to 10.
        top_p: Cumulative-probability threshold for nucleus filtering; must be
            between 0 and 1. Defaults to 1 (disabled).
        temperature: Divisor applied to the logits before filtering; must be
            strictly positive. Higher values make low-probability tokens more
            likely. Defaults to 1.0 (no scaling).

    Returns:
        Tensor of sampled token indices with one sample per row.
    """
    # Temperature scaling is skipped entirely at the neutral value.
    scaled = logits if temperature == 1.0 else logits / temperature
    # Top-p / top-k filtering.
    filtered = top_k_top_p_filtering(scaled, top_k=top_k, top_p=top_p)
    # Draw a single token from the filtered distribution.
    return torch.multinomial(F.softmax(filtered, dim=-1), num_samples=1)
98
+
99
+
100
+ from typing import Optional, Tuple
101
+
102
+
103
def multinomial_sample_one_no_sync(
    probs_sort,
):  # Does multinomial sampling without a cuda synchronization
    """Draw one index per row from ``probs_sort`` without ``torch.multinomial``.

    Each probability is divided by independent Exponential(1) noise and the
    arg-max over the last dimension is taken.

    Returns:
        ``torch.int`` tensor of chosen indices with the last dimension kept
        as size 1.
    """
    noise = torch.empty_like(probs_sort)
    noise.exponential_(1)
    winner = torch.argmax(probs_sort / noise, dim=-1, keepdim=True)
    return winner.to(dtype=torch.int)
108
+
109
+
110
def logits_to_probs(
    logits,
    previous_tokens: Optional[torch.Tensor] = None,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    repetition_penalty: float = 1.0,
):
    """Turn 1-D logits into a sampling distribution.

    Steps, in order: repetition penalty, top-p filtering, temperature scaling,
    top-k filtering, softmax.

    Args:
        logits: 1-D tensor of vocabulary logits. NOTE: when a repetition
            penalty is applied it is written into this tensor in place, so the
            caller's tensor (or the tensor it is a view of) is modified.
        previous_tokens: Already generated token ids; singleton dimensions are
            squeezed away before use.
        temperature: Divisor for the logits (clamped below at ``1e-5``).
        top_k: If given, keep only the ``top_k`` largest logits.
        top_p: If given and < 1.0, nucleus-filter with this cumulative
            probability threshold.
        repetition_penalty: Penalty applied to the logits of previous tokens;
            ``1.0`` disables it.

    Returns:
        Probability tensor with the same shape as ``logits``.
    """
    if previous_tokens is not None:
        previous_tokens = previous_tokens.squeeze()
        if previous_tokens.dim() == 0:
            # squeeze() collapses a single-token history to a scalar; restore
            # the 1-D shape expected by gather/scatter on 1-D logits.
            previous_tokens = previous_tokens.unsqueeze(0)

    if previous_tokens is not None and repetition_penalty != 1.0:
        previous_tokens = previous_tokens.long()
        score = torch.gather(logits, dim=0, index=previous_tokens)
        # Push already-seen tokens away from being chosen: shrink positive
        # logits, make negative logits more negative.
        score = torch.where(
            score < 0, score * repetition_penalty, score / repetition_penalty
        )
        # Deliberately in place (as before): the penalized values stay visible
        # to code that reads the same logits tensor after sampling.
        logits.scatter_(dim=0, index=previous_tokens, src=score)

    if top_p is not None and top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cum_probs = torch.cumsum(
            torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
        )
        sorted_indices_to_remove = cum_probs > top_p
        sorted_indices_to_remove[0] = False  # keep at least one option
        indices_to_remove = sorted_indices_to_remove.scatter(
            dim=0, index=sorted_indices, src=sorted_indices_to_remove
        )
        logits = logits.masked_fill(indices_to_remove, -float("Inf"))

    logits = logits / max(temperature, 1e-5)

    if top_k is not None:
        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        pivot = v.select(-1, -1).unsqueeze(-1)
        logits = torch.where(logits < pivot, -float("Inf"), logits)

    probs = torch.nn.functional.softmax(logits, dim=-1)
    return probs
151
+
152
+
153
def sample(
    logits,
    previous_tokens: Optional[torch.Tensor] = None,
    **sampling_kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Sample the next token from ``logits``.

    Args:
        logits: Logits passed to ``logits_to_probs``.
        previous_tokens: Optional token history passed to ``logits_to_probs``.
        **sampling_kwargs: Extra keyword arguments forwarded to
            ``logits_to_probs``.

    Returns:
        ``(idx_next, probs)``: the sampled index and the distribution it was
        drawn from.
    """
    distribution = logits_to_probs(
        logits=logits, previous_tokens=previous_tokens, **sampling_kwargs
    )
    chosen = multinomial_sample_one_no_sync(distribution)
    return chosen, distribution
163
+
164
def dpo_loss(
    policy_chosen_logps: torch.FloatTensor,
    policy_rejected_logps: torch.FloatTensor,
    reference_chosen_logps: torch.FloatTensor,
    reference_rejected_logps: torch.FloatTensor,
    beta: float,
    reference_free: bool = False,
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
    """Compute the DPO loss and the detached per-sample rewards.

    Args:
        policy_chosen_logps: Policy log-probabilities of the chosen samples.
        policy_rejected_logps: Policy log-probabilities of the rejected samples.
        reference_chosen_logps: Reference log-probabilities of the chosen samples.
        reference_rejected_logps: Reference log-probabilities of the rejected samples.
        beta: Scale applied to the log-ratio difference and to the rewards.
        reference_free: If ``True`` the reference log-ratio is replaced by 0.

    Returns:
        ``(loss, chosen_rewards, rejected_rewards)`` where ``loss`` is the mean
        of ``-logsigmoid(beta * (policy_margin - reference_margin))`` and the
        rewards are ``beta * (policy - reference)`` log-probabilities, detached.
    """
    policy_margin = policy_chosen_logps - policy_rejected_logps
    reference_margin = reference_chosen_logps - reference_rejected_logps
    if reference_free:
        reference_margin = 0

    per_sample_loss = -F.logsigmoid(beta * (policy_margin - reference_margin))

    chosen_gap = (policy_chosen_logps - reference_chosen_logps).detach()
    rejected_gap = (policy_rejected_logps - reference_rejected_logps).detach()

    return per_sample_loss.mean(), beta * chosen_gap, beta * rejected_gap
183
+
184
def get_batch_logps(
    logits_target: torch.FloatTensor,
    logits_reject: torch.FloatTensor,
    labels_target: torch.LongTensor,
    labels_reject: torch.LongTensor,
    average_log_prob: bool = False,
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
    """Compute per-sequence log-probabilities of the target and reject labels.

    Args:
        logits_target: Logits for the target sequences, vocabulary on dim 2.
        logits_reject: Logits for the rejected sequences, vocabulary on dim 2.
        labels_target: Token ids for the target sequences (indexes dim 2 of
            ``logits_target``).
        labels_reject: Token ids for the rejected sequences (indexes dim 2 of
            ``logits_reject``).
        average_log_prob: If ``True`` return the mean per-token
            log-probability of each sequence; otherwise (default) the sum.
            No padding mask is applied in either mode.

    Returns:
        ``(target_logps, reject_logps)``, one value per sequence.
    """

    def _sequence_logps(logits, labels):
        # Log-probability the model assigns to each label token.
        per_token = torch.gather(
            logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)
        ).squeeze(2)
        return per_token.mean(-1) if average_log_prob else per_token.sum(-1)

    return (
        _sequence_logps(logits_target, labels_target),
        _sequence_logps(logits_reject, labels_reject),
    )
192
+
193
def make_reject_y(y_o, y_lens):
    """Build corrupted ("rejected") copies of a batch of token sequences.

    Each row of ``y_o`` is corrupted with one of two operations picked
    uniformly at random: repeating a random span, or dropping a random span.
    The results are right-padded with zeros to a common length.

    Args:
        y_o: Batch of token sequences; row ``b`` is corrupted as a whole
            (``y_lens`` is only used for the batch size and the output device).
        y_lens: Per-row lengths of the original batch.

    Returns:
        ``(reject_y, reject_y_lens)``: the padded corrupted batch and a tensor
        with the unpadded length of every corrupted row.
    """

    def _random_span(y):
        # Two sorted random cut points in [0, len(y)).
        bounds, _ = torch.randint(0, len(y), size=(2,)).sort()
        return bounds[0], bounds[1]

    def repeat_P(y):
        # Duplicate the span between the cut points.
        start, end = _random_span(y)
        span = y[start:end]
        return torch.cat([y[:start], span, span, y[end:]])

    def lost_P(y):
        # Drop the span between the cut points.
        start, end = _random_span(y)
        return torch.cat([y[:start], y[end:]])

    bs = len(y_lens)
    reject_y = []
    reject_y_lens = []
    for b in range(bs):
        # The upper bound of randint is exclusive, so 2 is needed to pick
        # between both corruptions (with 1 the deletion branch never ran).
        if torch.randint(0, 2, size=(1,)).item() == 0:
            new_y = repeat_P(y_o[b])
        else:
            new_y = lost_P(y_o[b])
        reject_y.append(new_y)
        reject_y_lens.append(len(new_y))

    max_length = max(reject_y_lens)
    for b in range(bs):
        pad_length = max_length - reject_y_lens[b]
        padding = torch.zeros(pad_length, dtype=y_o.dtype, device=y_o.device)
        reject_y[b] = torch.cat([reject_y[b], padding], dim=0)

    reject_y = torch.stack(reject_y, dim=0)
    reject_y_lens = torch.tensor(reject_y_lens, device=y_lens.device)

    return reject_y, reject_y_lens
AR/modules/__init__.py ADDED
File without changes
AR/modules/activation.py ADDED
@@ -0,0 +1,428 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/activation.py
2
+ from typing import Optional
3
+ from typing import Tuple
4
+ import torch
5
+ from torch import Tensor
6
+ from torch.nn import Linear
7
+ from torch.nn import Module
8
+ from torch.nn.init import constant_
9
+ from torch.nn.init import xavier_normal_
10
+ from torch.nn.init import xavier_uniform_
11
+ from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
12
+ from torch.nn.parameter import Parameter
13
+
14
+ from torch.nn import functional as F
15
+ from AR.modules.patched_mha_with_cache import multi_head_attention_forward_patched
16
+
17
# Rebind torch.nn.functional.multi_head_attention_forward to the patched
# implementation imported above; MultiheadAttention.forward below calls it
# with an extra ``cache`` keyword.
# NOTE(review): this rebinding is global to torch.nn.functional, so it
# presumably also affects any other code in the process that calls
# F.multi_head_attention_forward — confirm that is intended.
F.multi_head_attention_forward = multi_head_attention_forward_patched
18
+
19
+
20
class MultiheadAttention(Module):
    r"""Allows the model to jointly attend to information
    from different representation subspaces as described in the paper:
    `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.

    Multi-Head Attention is defined as:

    .. math::
        \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O

    where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`.

    ``forward()`` will use a special optimized implementation if all of the following
    conditions are met:

    - self attention is being computed (i.e., ``query``, ``key``, and ``value`` are the same tensor. This
      restriction will be loosened in the future.)
    - Either autograd is disabled (using ``torch.inference_mode`` or ``torch.no_grad``) or no tensor argument ``requires_grad``
    - training is disabled (using ``.eval()``)
    - dropout is 0
    - ``add_bias_kv`` is ``False``
    - ``add_zero_attn`` is ``False``
    - ``batch_first`` is ``True`` and the input is batched
    - ``kdim`` and ``vdim`` are equal to ``embed_dim``
    - at most one of ``key_padding_mask`` or ``attn_mask`` is passed
    - if a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ is passed, neither ``key_padding_mask``
      nor ``attn_mask`` is passed
    - no ``cache`` is passed (the optimized implementation has no cache support)

    If the optimized implementation is in use, a
    `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ can be passed for
    ``query``/``key``/``value`` to represent padding more efficiently than using a
    padding mask. In this case, a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_
    will be returned, and an additional speedup proportional to the fraction of the input
    that is padding can be expected.

    Args:
        embed_dim: Total dimension of the model.
        num_heads: Number of parallel attention heads. Note that ``embed_dim`` will be split
            across ``num_heads`` (i.e. each head will have dimension ``embed_dim // num_heads``).
        dropout: Dropout probability on ``attn_output_weights``. Default: ``0.0`` (no dropout).
        bias: If specified, adds bias to input / output projection layers. Default: ``True``.
        add_bias_kv: If specified, adds bias to the key and value sequences at dim=0. Default: ``False``.
        add_zero_attn: If specified, adds a new batch of zeros to the key and value sequences at dim=1.
            Default: ``False``.
        kdim: Total number of features for keys. Default: ``None`` (uses ``kdim=embed_dim``).
        vdim: Total number of features for values. Default: ``None`` (uses ``vdim=embed_dim``).
        batch_first: If ``True``, then the input and output tensors are provided
            as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
        linear1_cls: Class used for the fused input projection. With the default
            ``Linear`` the projection weights are stored as plain parameters;
            any other class is instantiated as ``in_proj_linear``.
        linear2_cls: Class used for the output projection when ``linear1_cls``
            is not ``Linear``.

    Examples::

        >>> # xdoctest: +SKIP
        >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
        >>> attn_output, attn_output_weights = multihead_attn(query, key, value)

    """
    __constants__ = ["batch_first"]
    bias_k: Optional[torch.Tensor]
    bias_v: Optional[torch.Tensor]

    def __init__(
        self,
        embed_dim,
        num_heads,
        dropout=0.0,
        bias=True,
        add_bias_kv=False,
        add_zero_attn=False,
        kdim=None,
        vdim=None,
        batch_first=False,
        linear1_cls=Linear,
        linear2_cls=Linear,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout = dropout
        self.batch_first = batch_first
        self.head_dim = embed_dim // num_heads
        assert (
            self.head_dim * num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"

        if add_bias_kv:
            self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
            self.bias_v = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
        else:
            self.bias_k = self.bias_v = None

        if linear1_cls == Linear:
            # Stock layout: raw projection parameters, as in torch.nn.MultiheadAttention.
            if not self._qkv_same_embed_dim:
                self.q_proj_weight = Parameter(
                    torch.empty((embed_dim, embed_dim), **factory_kwargs)
                )
                self.k_proj_weight = Parameter(
                    torch.empty((embed_dim, self.kdim), **factory_kwargs)
                )
                self.v_proj_weight = Parameter(
                    torch.empty((embed_dim, self.vdim), **factory_kwargs)
                )
                self.register_parameter("in_proj_weight", None)
            else:
                self.in_proj_weight = Parameter(
                    torch.empty((3 * embed_dim, embed_dim), **factory_kwargs)
                )
                self.register_parameter("q_proj_weight", None)
                self.register_parameter("k_proj_weight", None)
                self.register_parameter("v_proj_weight", None)

            if bias:
                self.in_proj_bias = Parameter(
                    torch.empty(3 * embed_dim, **factory_kwargs)
                )
            else:
                self.register_parameter("in_proj_bias", None)
            self.out_proj = NonDynamicallyQuantizableLinear(
                embed_dim, embed_dim, bias=bias, **factory_kwargs
            )

            self._reset_parameters()
        else:
            # Custom linear classes: only the fused (same-dimension) projection
            # is supported; the layer's own weight/bias are exposed under the
            # standard attribute names.
            if not self._qkv_same_embed_dim:
                raise NotImplementedError
            else:
                self.in_proj_linear = linear1_cls(
                    embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs
                )
                self.in_proj_weight = self.in_proj_linear.weight

                self.register_parameter("q_proj_weight", None)
                self.register_parameter("k_proj_weight", None)
                self.register_parameter("v_proj_weight", None)

                if bias:
                    self.in_proj_bias = self.in_proj_linear.bias
                else:
                    self.register_parameter("in_proj_bias", None)

            self.out_proj = linear2_cls(
                embed_dim, embed_dim, bias=bias, **factory_kwargs
            )

            if self.bias_k is not None:
                xavier_normal_(self.bias_k)
            if self.bias_v is not None:
                xavier_normal_(self.bias_v)

        self.add_zero_attn = add_zero_attn

    def _reset_parameters(self):
        """Initialize projection weights (Xavier) and zero the projection biases."""
        if self._qkv_same_embed_dim:
            xavier_uniform_(self.in_proj_weight)
        else:
            xavier_uniform_(self.q_proj_weight)
            xavier_uniform_(self.k_proj_weight)
            xavier_uniform_(self.v_proj_weight)

        if self.in_proj_bias is not None:
            constant_(self.in_proj_bias, 0.0)
            constant_(self.out_proj.bias, 0.0)

        if self.bias_k is not None:
            xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            xavier_normal_(self.bias_v)

    def __setstate__(self, state):
        # Support loading old MultiheadAttention checkpoints generated by v1.1.0
        if "_qkv_same_embed_dim" not in state:
            state["_qkv_same_embed_dim"] = True

        super().__setstate__(state)

    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        key_padding_mask: Optional[Tensor] = None,
        need_weights: bool = True,
        attn_mask: Optional[Tensor] = None,
        average_attn_weights: bool = True,
        cache=None,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        r"""
        Args:
            query: Query embeddings of shape :math:`(L, E_q)` for unbatched input, :math:`(L, N, E_q)` when ``batch_first=False``
                or :math:`(N, L, E_q)` when ``batch_first=True``, where :math:`L` is the target sequence length,
                :math:`N` is the batch size, and :math:`E_q` is the query embedding dimension ``embed_dim``.
                Queries are compared against key-value pairs to produce the output.
                See "Attention Is All You Need" for more details.
            key: Key embeddings of shape :math:`(S, E_k)` for unbatched input, :math:`(S, N, E_k)` when ``batch_first=False``
                or :math:`(N, S, E_k)` when ``batch_first=True``, where :math:`S` is the source sequence length,
                :math:`N` is the batch size, and :math:`E_k` is the key embedding dimension ``kdim``.
                See "Attention Is All You Need" for more details.
            value: Value embeddings of shape :math:`(S, E_v)` for unbatched input, :math:`(S, N, E_v)` when
                ``batch_first=False`` or :math:`(N, S, E_v)` when ``batch_first=True``, where :math:`S` is the source
                sequence length, :math:`N` is the batch size, and :math:`E_v` is the value embedding dimension ``vdim``.
                See "Attention Is All You Need" for more details.
            key_padding_mask: If specified, a mask of shape :math:`(N, S)` indicating which elements within ``key``
                to ignore for the purpose of attention (i.e. treat as "padding"). For unbatched `query`, shape should be :math:`(S)`.
                Binary and byte masks are supported.
                For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for
                the purpose of attention. For a float mask, it will be directly added to the corresponding ``key`` value.
            need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``.
                Default: ``True``.
            attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape
                :math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the batch size,
                :math:`L` is the target sequence length, and :math:`S` is the source sequence length. A 2D mask will be
                broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch.
                Binary, byte, and float masks are supported. For a binary mask, a ``True`` value indicates that the
                corresponding position is not allowed to attend. For a byte mask, a non-zero value indicates that the
                corresponding position is not allowed to attend. For a float mask, the mask values will be added to
                the attention weight.
            average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across
                heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an
                effect when ``need_weights=True``. Default: ``True`` (i.e. average weights across heads)
            cache: Optional cache object forwarded unchanged to
                ``F.multi_head_attention_forward``; its layout is defined by that
                function. Passing a cache disables the native fast path, which
                cannot use it.

        Outputs:
            - **attn_output** - Attention outputs of shape :math:`(L, E)` when input is unbatched,
              :math:`(L, N, E)` when ``batch_first=False`` or :math:`(N, L, E)` when ``batch_first=True``,
              where :math:`L` is the target sequence length, :math:`N` is the batch size, and :math:`E` is the
              embedding dimension ``embed_dim``.
            - **attn_output_weights** - Only returned when ``need_weights=True``. If ``average_attn_weights=True``,
              returns attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or
              :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and
              :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
              head of shape :math:`(\text{num\_heads}, L, S)` when input is unbatched or :math:`(N, \text{num\_heads}, L, S)`.

            .. note::
                `batch_first` argument is ignored for unbatched inputs.
        """
        is_batched = query.dim() == 3
        if key_padding_mask is not None:
            _kpm_dtype = key_padding_mask.dtype
            if _kpm_dtype != torch.bool and not torch.is_floating_point(
                key_padding_mask
            ):
                raise AssertionError(
                    "only bool and floating types of key_padding_mask are supported"
                )
        why_not_fast_path = ""
        if not is_batched:
            why_not_fast_path = (
                f"input not batched; expected query.dim() of 3 but got {query.dim()}"
            )
        elif query is not key or key is not value:
            # When lifting this restriction, don't forget to either
            # enforce that the dtypes all match or test cases where
            # they don't!
            why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)"
        elif self.in_proj_bias is not None and query.dtype != self.in_proj_bias.dtype:
            why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
        elif (
            self.in_proj_weight is not None and query.dtype != self.in_proj_weight.dtype
        ):
            # this case will fail anyway, but at least they'll get a useful error message.
            why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
        elif self.training:
            why_not_fast_path = "training is enabled"
        elif not self.batch_first:
            why_not_fast_path = "batch_first was not True"
        elif self.bias_k is not None:
            why_not_fast_path = "self.bias_k was not None"
        elif self.bias_v is not None:
            why_not_fast_path = "self.bias_v was not None"
        elif self.dropout:
            why_not_fast_path = f"dropout was {self.dropout}, required zero"
        elif self.add_zero_attn:
            why_not_fast_path = "add_zero_attn was enabled"
        elif not self._qkv_same_embed_dim:
            why_not_fast_path = "_qkv_same_embed_dim was not True"
        elif attn_mask is not None:
            why_not_fast_path = "attn_mask was not None"
        elif cache is not None:
            # The native kernel takes no cache argument; using it here would
            # silently drop the cache.
            why_not_fast_path = "cache was not None"
        elif query.is_nested and key_padding_mask is not None:
            why_not_fast_path = (
                "key_padding_mask is not supported with NestedTensor input"
            )
        elif self.num_heads % 2 == 1:
            why_not_fast_path = "num_heads is odd"
        elif torch.is_autocast_enabled():
            why_not_fast_path = "autocast is enabled"

        if not why_not_fast_path:
            tensor_args = (
                query,
                key,
                value,
                self.in_proj_weight,
                self.in_proj_bias,
                self.out_proj.weight,
                self.out_proj.bias,
            )
            # We have to use list comprehensions below because TorchScript does not support
            # generator expressions.
            if torch.overrides.has_torch_function(tensor_args):
                why_not_fast_path = "some Tensor argument has_torch_function"
            elif not all(
                [
                    (x is None or x.is_cuda or "cpu" in str(x.device))
                    for x in tensor_args
                ]
            ):
                why_not_fast_path = "some Tensor argument is neither CUDA nor CPU"
            elif torch.is_grad_enabled() and any(
                [x is not None and x.requires_grad for x in tensor_args]
            ):
                why_not_fast_path = (
                    "grad is enabled and at least one of query or the "
                    "input/output projection weights or biases requires_grad"
                )
            if not why_not_fast_path:
                return torch._native_multi_head_attention(
                    query,
                    key,
                    value,
                    self.embed_dim,
                    self.num_heads,
                    self.in_proj_weight,
                    self.in_proj_bias,
                    self.out_proj.weight,
                    self.out_proj.bias,
                    key_padding_mask if key_padding_mask is not None else attn_mask,
                    need_weights,
                    average_attn_weights,
                    1
                    if key_padding_mask is not None
                    else 0
                    if attn_mask is not None
                    else None,
                )

        any_nested = query.is_nested or key.is_nested or value.is_nested
        assert not any_nested, (
            "MultiheadAttention does not support NestedTensor outside of its fast path. "
            + f"The fast path was not hit because {why_not_fast_path}"
        )

        if self.batch_first and is_batched:
            # make sure that the transpose op does not affect the "is" property
            if key is value:
                if query is key:
                    query = key = value = query.transpose(1, 0)
                else:
                    query, key = [x.transpose(1, 0) for x in (query, key)]
                    value = key
            else:
                query, key, value = [x.transpose(1, 0) for x in (query, key, value)]

        if not self._qkv_same_embed_dim:
            attn_output, attn_output_weights = F.multi_head_attention_forward(
                query,
                key,
                value,
                self.embed_dim,
                self.num_heads,
                self.in_proj_weight,
                self.in_proj_bias,
                self.bias_k,
                self.bias_v,
                self.add_zero_attn,
                self.dropout,
                self.out_proj.weight,
                self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask,
                need_weights=need_weights,
                attn_mask=attn_mask,
                use_separate_proj_weight=True,
                q_proj_weight=self.q_proj_weight,
                k_proj_weight=self.k_proj_weight,
                v_proj_weight=self.v_proj_weight,
                average_attn_weights=average_attn_weights,
                cache=cache,
            )
        else:
            attn_output, attn_output_weights = F.multi_head_attention_forward(
                query,
                key,
                value,
                self.embed_dim,
                self.num_heads,
                self.in_proj_weight,
                self.in_proj_bias,
                self.bias_k,
                self.bias_v,
                self.add_zero_attn,
                self.dropout,
                self.out_proj.weight,
                self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask,
                need_weights=need_weights,
                attn_mask=attn_mask,
                average_attn_weights=average_attn_weights,
                cache=cache,
            )
        if self.batch_first and is_batched:
            return attn_output.transpose(1, 0), attn_output_weights
        else:
            return attn_output, attn_output_weights
AR/modules/activation_onnx.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/activation.py
2
+ from typing import Optional
3
+ from typing import Tuple
4
+ import torch
5
+ from torch import Tensor
6
+ from torch.nn import Linear
7
+ from torch.nn import Module
8
+ from torch.nn.init import constant_
9
+ from torch.nn.init import xavier_normal_
10
+ from torch.nn.init import xavier_uniform_
11
+ from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
12
+ from torch.nn.parameter import Parameter
13
+
14
+ from torch.nn import functional as F
15
+ from AR.modules.patched_mha_with_cache_onnx import multi_head_attention_forward_patched
16
+
17
+
18
class MultiheadAttention(Module):
    """Multi-head self-attention variant used on the ONNX export path.

    The constructor mirrors ``torch.nn.MultiheadAttention`` (plus the
    ``linear1_cls`` / ``linear2_cls`` hooks). ``forward`` always performs
    self-attention on ``query`` and delegates to
    ``multi_head_attention_forward_patched``.
    """

    __constants__ = ["batch_first"]
    bias_k: Optional[torch.Tensor]
    bias_v: Optional[torch.Tensor]

    def __init__(
        self,
        embed_dim,
        num_heads,
        dropout=0.0,
        bias=True,
        add_bias_kv=False,
        add_zero_attn=False,
        kdim=None,
        vdim=None,
        batch_first=False,
        linear1_cls=Linear,
        linear2_cls=Linear,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout = dropout
        self.batch_first = batch_first
        self.head_dim = embed_dim // num_heads
        assert (
            self.head_dim * num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"

        if add_bias_kv:
            self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
            self.bias_v = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
        else:
            self.bias_k = self.bias_v = None

        if linear1_cls == Linear:
            # Stock layout: raw projection parameters.
            if not self._qkv_same_embed_dim:
                self.q_proj_weight = Parameter(
                    torch.empty((embed_dim, embed_dim), **factory_kwargs)
                )
                self.k_proj_weight = Parameter(
                    torch.empty((embed_dim, self.kdim), **factory_kwargs)
                )
                self.v_proj_weight = Parameter(
                    torch.empty((embed_dim, self.vdim), **factory_kwargs)
                )
                self.register_parameter("in_proj_weight", None)
            else:
                self.in_proj_weight = Parameter(
                    torch.empty((3 * embed_dim, embed_dim), **factory_kwargs)
                )
                self.register_parameter("q_proj_weight", None)
                self.register_parameter("k_proj_weight", None)
                self.register_parameter("v_proj_weight", None)

            if bias:
                self.in_proj_bias = Parameter(
                    torch.empty(3 * embed_dim, **factory_kwargs)
                )
            else:
                self.register_parameter("in_proj_bias", None)
            self.out_proj = NonDynamicallyQuantizableLinear(
                embed_dim, embed_dim, bias=bias, **factory_kwargs
            )

            self._reset_parameters()
        else:
            # Custom linear classes: only the fused (same-dimension) projection
            # is supported; the layer's own weight/bias are exposed under the
            # standard attribute names.
            if not self._qkv_same_embed_dim:
                raise NotImplementedError
            else:
                self.in_proj_linear = linear1_cls(
                    embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs
                )
                self.in_proj_weight = self.in_proj_linear.weight

                self.register_parameter("q_proj_weight", None)
                self.register_parameter("k_proj_weight", None)
                self.register_parameter("v_proj_weight", None)

                if bias:
                    self.in_proj_bias = self.in_proj_linear.bias
                else:
                    self.register_parameter("in_proj_bias", None)

            self.out_proj = linear2_cls(
                embed_dim, embed_dim, bias=bias, **factory_kwargs
            )

            if self.bias_k is not None:
                xavier_normal_(self.bias_k)
            if self.bias_v is not None:
                xavier_normal_(self.bias_v)

        self.add_zero_attn = add_zero_attn

    def _reset_parameters(self):
        """Initialize projection weights (Xavier) and zero the projection biases."""
        if self._qkv_same_embed_dim:
            xavier_uniform_(self.in_proj_weight)
        else:
            xavier_uniform_(self.q_proj_weight)
            xavier_uniform_(self.k_proj_weight)
            xavier_uniform_(self.v_proj_weight)

        if self.in_proj_bias is not None:
            constant_(self.in_proj_bias, 0.0)
            constant_(self.out_proj.bias, 0.0)

        if self.bias_k is not None:
            xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            xavier_normal_(self.bias_v)

    def __setstate__(self, state):
        # Support loading old MultiheadAttention checkpoints generated by v1.1.0
        if "_qkv_same_embed_dim" not in state:
            state["_qkv_same_embed_dim"] = True

        super().__setstate__(state)

    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        key_padding_mask: Optional[Tensor] = None,
        need_weights: bool = True,
        attn_mask: Optional[Tensor] = None,
        average_attn_weights: bool = True,
        cache=None,
    ) -> Tensor:
        """Run self-attention on ``query`` and return only the attention output.

        Args:
            query: Input tensor; dimensions 0 and 1 are swapped before and
                after the attention call.
            key: Ignored; the transposed ``query`` is used as key.
            value: Ignored; the transposed ``query`` is used as value.
            key_padding_mask: Forwarded to the patched attention function.
            need_weights: Forwarded to the patched attention function.
            attn_mask: Forwarded to the patched attention function.
            average_attn_weights: Forwarded to the patched attention function.
            cache: Cache object forwarded to the patched attention function.

        Returns:
            The attention output with dimensions 0 and 1 swapped back. No
            attention weights are returned.
        """
        # Self-attention only: key and value are always the transposed query.
        query = key = value = query.transpose(1, 0)
        attn_output = multi_head_attention_forward_patched(
            query,
            key,
            value,
            self.embed_dim,
            self.num_heads,
            self.in_proj_weight,
            self.in_proj_bias,
            self.bias_k,
            self.bias_v,
            self.add_zero_attn,
            self.dropout,
            self.out_proj.weight,
            self.out_proj.bias,
            training=self.training,
            key_padding_mask=key_padding_mask,
            need_weights=need_weights,
            attn_mask=attn_mask,
            average_attn_weights=average_attn_weights,
            cache=cache,
        )
        return attn_output.transpose(1, 0)
AR/modules/embedding.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/embedding.py
2
+ import math
3
+
4
+ import torch
5
+ from torch import nn
6
+
7
+
8
class TokenEmbedding(nn.Module):
    """Token-id lookup table followed by dropout."""

    def __init__(
        self,
        embedding_dim: int,
        vocab_size: int,
        dropout: float = 0.0,
    ):
        """Create the embedding table.

        Args:
            embedding_dim: Width of each embedding vector.
            vocab_size: Number of rows in the embedding table.
            dropout: Dropout probability applied to looked-up embeddings.
        """
        super().__init__()

        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim

        # Sub-modules are registered in the same order as before (dropout
        # first) so module iteration order is unchanged.
        self.dropout = torch.nn.Dropout(p=dropout)
        self.word_embeddings = nn.Embedding(self.vocab_size, self.embedding_dim)

    @property
    def weight(self) -> torch.Tensor:
        """The full embedding matrix held by ``word_embeddings``."""
        table = self.word_embeddings
        return table.weight

    def embedding(self, index: int) -> torch.Tensor:
        """Return the row for ``index`` as a one-row slice of the table."""
        row = slice(index, index + 1)
        return self.word_embeddings.weight[row]

    def forward(self, x: torch.Tensor):
        """Look up embeddings for the ids in ``x`` and apply dropout."""
        looked_up = self.word_embeddings(x)
        return self.dropout(looked_up)
34
+
35
+
36
class SinePositionalEmbedding(nn.Module):
    """Add sinusoidal positional encodings to an input sequence.

    The encoding table is cached in ``self.pe`` (a plain attribute, not a
    registered buffer) and is grown or moved to the input's dtype/device on
    demand by :meth:`extend_pe`.
    """

    def __init__(
        self,
        embedding_dim: int,
        dropout: float = 0.0,
        scale: bool = False,
        alpha: bool = False,
    ):
        """Build the module and pre-compute encodings for 4000 positions.

        Args:
            embedding_dim: Width of the positional encoding (odd widths are
                supported; the last column is then a sine column).
            dropout: Dropout probability applied to the output.
            scale: If true, the input is multiplied by ``sqrt(embedding_dim)``.
            alpha: Whether the scalar weight on the positional term is
                trainable (``requires_grad``); it is initialised to 1.
        """
        super().__init__()
        self.embedding_dim = embedding_dim
        self.x_scale = math.sqrt(embedding_dim) if scale else 1.0
        self.alpha = nn.Parameter(torch.ones(1), requires_grad=alpha)
        self.dropout = torch.nn.Dropout(p=dropout)

        self.reverse = False
        self.pe = None
        # Only the size of dim 1 (and dtype/device) of this dummy is used.
        self.extend_pe(torch.tensor(0.0).expand(1, 4000))

    def extend_pe(self, x):
        """Make sure ``self.pe`` covers ``x.size(1)`` positions.

        If the cached table is already long enough it is only converted to
        ``x``'s dtype/device when those differ; otherwise the table is rebuilt
        for exactly ``x.size(1)`` positions.
        """
        length = x.size(1)
        if self.pe is not None and self.pe.size(1) >= length:
            if self.pe.dtype != x.dtype or self.pe.device != x.device:
                self.pe = self.pe.to(dtype=x.dtype, device=x.device)
            return

        pe = torch.zeros(length, self.embedding_dim)
        if self.reverse:
            position = torch.arange(
                length - 1, -1, -1.0, dtype=torch.float32
            ).unsqueeze(1)
        else:
            position = torch.arange(0, length, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.embedding_dim, 2, dtype=torch.float32)
            * -(math.log(10000.0) / self.embedding_dim)
        )
        angles = position * div_term
        pe[:, 0::2] = torch.sin(angles)
        # For an odd embedding_dim there is one fewer cosine column than sine
        # columns, so trim the angles; for even widths this slice is a no-op.
        pe[:, 1::2] = torch.cos(angles[:, : self.embedding_dim // 2])
        pe = pe.unsqueeze(0)
        self.pe = pe.to(device=x.device, dtype=x.dtype).detach()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``dropout(x * x_scale + alpha * pe[:, :x.size(1)])``.

        A 2-D ``x`` gets a trailing singleton dimension so it broadcasts
        against the encoding table.
        """
        self.extend_pe(x)
        output = x.unsqueeze(-1) if x.ndim == 2 else x
        output = output * self.x_scale + self.alpha * self.pe[:, : x.size(1)]
        return self.dropout(output)
AR/modules/embedding_onnx.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/embedding.py
2
+ import math
3
+
4
+ import torch
5
+ from torch import nn
6
+
7
+
8
class TokenEmbedding(nn.Module):
    """Embedding table for token ids with dropout on the looked-up vectors."""

    def __init__(
        self,
        embedding_dim: int,
        vocab_size: int,
        dropout: float = 0.0,
    ):
        """Set up the lookup table.

        Args:
            embedding_dim: Size of each embedding vector.
            vocab_size: Number of distinct token ids the table holds.
            dropout: Probability used by the dropout layer in ``forward``.
        """
        super().__init__()

        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim

        # Registration order (dropout, then the table) matches the original.
        self.dropout = torch.nn.Dropout(p=dropout)
        self.word_embeddings = nn.Embedding(self.vocab_size, self.embedding_dim)

    @property
    def weight(self) -> torch.Tensor:
        """Expose the weight matrix of the underlying ``nn.Embedding``."""
        return self.word_embeddings.weight

    def embedding(self, index: int) -> torch.Tensor:
        """Slice out the single table row at ``index`` (kept 2-D)."""
        start, stop = index, index + 1
        return self.word_embeddings.weight[start:stop]

    def forward(self, x: torch.Tensor):
        """Embed the ids in ``x`` and pass the result through dropout."""
        return self.dropout(self.word_embeddings(x))
34
+
35
+
36
class SinePositionalEmbedding(nn.Module):
    """Sinusoidal positional embedding that recomputes its table per call.

    Unlike a cached table, the encodings are derived from the input on every
    forward pass using only tensor operations.
    """

    def __init__(
        self,
        embedding_dim: int,
        dropout: float = 0.0,
        scale: bool = False,
        alpha: bool = False,
    ):
        """Store configuration and pre-compute the frequency terms.

        Args:
            embedding_dim: Width of the positional encoding.
            dropout: Dropout probability applied to the output.
            scale: If true, the input is multiplied by ``sqrt(embedding_dim)``.
            alpha: Whether the scalar weight on the positional term is
                trainable; it is initialised to 1.
        """
        super().__init__()
        self.embedding_dim = embedding_dim
        self.x_scale = math.sqrt(embedding_dim) if scale else 1.0
        self.alpha = nn.Parameter(torch.ones(1), requires_grad=alpha)
        self.dropout = torch.nn.Dropout(p=dropout)
        self.reverse = False
        # Plain tensor attribute (not a registered buffer).
        even_indices = torch.arange(0, self.embedding_dim, 2)
        log_step = -(math.log(10000.0) / self.embedding_dim)
        self.div_term = torch.exp(even_indices * log_step)

    def extend_pe(self, x):
        """Compute the encoding table for the sequence length of ``x``.

        Positions are obtained as a running sum of ones along dim 1, so the
        first position has value 1.
        NOTE(review): the multiplication with ``div_term`` looks like it
        expects a batch size of 1 (or one that broadcasts) -- confirm against
        callers.
        """
        ones = torch.ones_like(x[:, :, 0])
        position = torch.cumsum(ones, dim=1).transpose(0, 1)
        scpe = (position * self.div_term).unsqueeze(0)
        # Stack sin/cos on a new leading axis, then interleave them so each
        # frequency contributes an adjacent (sin, cos) pair.
        stacked = torch.cat([torch.sin(scpe), torch.cos(scpe)])
        interleaved = stacked.permute(1, 2, 0)
        return interleaved.contiguous().view(1, -1, self.embedding_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``dropout(x * x_scale + alpha * pe)``."""
        pe = self.extend_pe(x)
        output = x.unsqueeze(-1) if x.ndim == 2 else x
        output = output * self.x_scale + self.alpha * pe
        return self.dropout(output)
AR/modules/lr_schedulers.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/modules/lr_schedulers.py
2
+ # reference: https://github.com/lifeiteng/vall-e
3
+ import math
4
+
5
+ import torch
6
+ from matplotlib import pyplot as plt
7
+ from torch import nn
8
+ from torch.optim import Adam
9
+
10
+
11
class WarmupCosineLRSchedule(torch.optim.lr_scheduler._LRScheduler):
    """Linear warm-up followed by cosine decay, with an optional fixed-rate lock.

    The schedule rises linearly from ``init_lr`` to ``peak_lr`` during the
    first ``warmup_steps`` steps, follows a half cosine from ``peak_lr`` to
    ``end_lr`` until ``total_steps`` and stays at ``end_lr`` afterwards.

    By default the schedule is *locked*: every step applies ``locked_lr``
    (0.002) to all parameter groups, exactly as the previous hard-coded
    implementation did. Pass ``locked_lr=None`` to use the warm-up/cosine
    schedule itself.

    The base-class ``__init__`` is intentionally not called; all state is kept
    in the attributes set below.
    """

    def __init__(
        self,
        optimizer,
        init_lr,
        peak_lr,
        end_lr,
        warmup_steps=10000,
        total_steps=400000,
        current_step=0,
        locked_lr=0.002,
    ):
        """Store the schedule parameters.

        Args:
            optimizer: Optimizer whose ``param_groups`` receive the rate.
            init_lr: Learning rate at step 0 of the warm-up.
            peak_lr: Learning rate reached at the end of the warm-up.
            end_lr: Learning rate at and after ``total_steps``.
            warmup_steps: Number of linear warm-up steps.
            total_steps: Step at which the cosine decay reaches ``end_lr``.
            current_step: Step counter to start from.
            locked_lr: Constant rate applied on every step, overriding the
                schedule; ``None`` disables the lock.
        """
        self.init_lr = init_lr
        self.peak_lr = peak_lr
        self.end_lr = end_lr
        self.optimizer = optimizer
        self.locked_lr = locked_lr
        decay_steps = total_steps - warmup_steps
        # Guard the degenerate configurations that used to raise
        # ZeroDivisionError at construction time.
        self._warmup_rate = (
            (peak_lr - init_lr) / warmup_steps if warmup_steps else 0.0
        )
        self._decay_rate = (end_lr - peak_lr) / decay_steps if decay_steps else 0.0
        self._current_step = current_step
        self.lr = init_lr
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self._last_lr = [self.lr]

    def set_lr(self, lr):
        """Write the learning rate into every parameter group.

        ``_last_lr`` is first set to the rates the groups had *before* this
        call. While the lock is active the ``lr`` argument is ignored and
        ``self.end_lr`` is applied instead (legacy behaviour).
        """
        self._last_lr = [g["lr"] for g in self.optimizer.param_groups]
        applied = lr if self.locked_lr is None else self.end_lr
        for g in self.optimizer.param_groups:
            g["lr"] = applied

    def _scheduled_lr(self):
        """Return the warm-up/cosine rate for the current step."""
        step = self._current_step
        if step < self.warmup_steps:
            return self.init_lr + self._warmup_rate * step
        if step > self.total_steps:
            return self.end_lr

        decay_steps = self.total_steps - self.warmup_steps
        if decay_steps <= 0:
            # warmup_steps == total_steps == step: no decay phase exists.
            return self.peak_lr
        decay_ratio = (step - self.warmup_steps) / decay_steps
        if decay_ratio < 0.0 or decay_ratio > 1.0:
            raise RuntimeError(
                "Decay ratio must be in [0.0, 1.0]. Fix LR scheduler settings."
            )
        coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
        return self.end_lr + coeff * (self.peak_lr - self.end_lr)

    def step(self):
        """Advance one step, apply the rate to the optimizer and return it."""
        if self.locked_lr is None:
            lr = self._scheduled_lr()
        else:
            # Legacy lock: the schedule is bypassed and ``end_lr`` is
            # overwritten with the fixed rate, as the hard-coded version did.
            lr = self.end_lr = self.locked_lr
        self.set_lr(lr)
        self.lr = lr
        self._current_step += 1
        return self.lr
67
+
68
+
69
if __name__ == "__main__":
    # Manual demo: step the scheduler for a fixed number of iterations,
    # print each learning rate and plot the resulting curve.
    demo_steps = 25000

    model = nn.Linear(10, 10)
    optimizer = Adam(model.parameters(), lr=1e-4)
    scheduler = WarmupCosineLRSchedule(
        optimizer,
        1e-6,
        2e-4,
        1e-6,
        warmup_steps=2000,
        total_steps=20000,
        current_step=0,
    )

    history = []
    for _ in range(demo_steps):
        scheduler.step()
        history.append(scheduler.lr)
        print(scheduler.lr)

    # NOTE(review): the curve is plotted twice (implicit and explicit x axis);
    # the two lines coincide -- one of the calls is probably redundant.
    plt.plot(history)
    plt.plot(range(0, demo_steps), history)
    plt.show()
AR/modules/optim.py ADDED
@@ -0,0 +1,622 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 Xiaomi Corp. (authors: Daniel Povey)
2
+ #
3
+ # See ../LICENSE for clarification regarding multiple authors
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ import contextlib
17
+ import logging
18
+ from collections import defaultdict
19
+ from typing import List
20
+ from typing import Tuple
21
+
22
+ import torch
23
+ from torch import Tensor
24
+ from torch.optim import Optimizer
25
+
26
+
27
+ class BatchedOptimizer(Optimizer):
28
+ """
29
+ This class adds to class Optimizer the capability to optimize parameters in batches:
30
+ it will stack the parameters and their grads for you so the optimizer can work
31
+ on tensors with an extra leading dimension. This is intended for speed with GPUs,
32
+ as it reduces the number of kernels launched in the optimizer.
33
+
34
+ Args:
35
+ params:
36
+ """
37
+
38
+ def __init__(self, params, defaults):
39
+ super(BatchedOptimizer, self).__init__(params, defaults)
40
+
41
+ @contextlib.contextmanager
42
+ def batched_params(self, param_group, group_params_names):
43
+ """
44
+ This function returns (technically, yields) a list of
45
+ of tuples (p, state), where
46
+ p is a `fake` parameter that is stacked (over axis 0) from real parameters
47
+ that share the same shape, and its gradient is also stacked;
48
+ `state` is the state corresponding to this batch of parameters
49
+ (it will be physically located in the "state" for one of the real
50
+ parameters, the last one that has any particular shape and dtype).
51
+
52
+ This function is decorated as a context manager so that it can
53
+ write parameters back to their "real" locations.
54
+
55
+ The idea is, instead of doing:
56
+ <code>
57
+ for p in group["params"]:
58
+ state = self.state[p]
59
+ ...
60
+ </code>
61
+ you can do:
62
+ <code>
63
+ with self.batched_params(group["params"]) as batches:
64
+ for p, state, p_names in batches:
65
+ ...
66
+ </code>
67
+
68
+ Args:
69
+ group: a parameter group, which is a list of parameters; should be
70
+ one of self.param_groups.
71
+ group_params_names: name for each parameter in group,
72
+ which is List[str].
73
+ """
74
+ batches = defaultdict(
75
+ list
76
+ ) # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter
77
+ batches_names = defaultdict(
78
+ list
79
+ ) # `batches` maps from tuple (dtype_as_str,*shape) to list of str
80
+
81
+ assert len(param_group) == len(group_params_names)
82
+ for p, named_p in zip(param_group, group_params_names):
83
+ key = (str(p.dtype), *p.shape)
84
+ batches[key].append(p)
85
+ batches_names[key].append(named_p)
86
+
87
+ batches_names_keys = list(batches_names.keys())
88
+ sorted_idx = sorted(
89
+ range(len(batches_names)), key=lambda i: batches_names_keys[i])
90
+ batches_names = [
91
+ batches_names[batches_names_keys[idx]] for idx in sorted_idx
92
+ ]
93
+ batches = [batches[batches_names_keys[idx]] for idx in sorted_idx]
94
+
95
+ stacked_params_dict = dict()
96
+
97
+ # turn batches into a list, in deterministic order.
98
+ # tuples will contain tuples of (stacked_param, state, stacked_params_names),
99
+ # one for each batch in `batches`.
100
+ tuples = []
101
+
102
+ for batch, batch_names in zip(batches, batches_names):
103
+ p = batch[0]
104
+ # we arbitrarily store the state in the
105
+ # state corresponding to the 1st parameter in the
106
+ # group. class Optimizer will take care of saving/loading state.
107
+ state = self.state[p]
108
+ p_stacked = torch.stack(batch)
109
+ grad = torch.stack([
110
+ torch.zeros_like(p) if p.grad is None else p.grad for p in batch
111
+ ])
112
+ p_stacked.grad = grad
113
+ stacked_params_dict[key] = p_stacked
114
+ tuples.append((p_stacked, state, batch_names))
115
+
116
+ yield tuples # <-- calling code will do the actual optimization here!
117
+
118
+ for ((stacked_params, _state, _names), batch) in zip(tuples, batches):
119
+ for i, p in enumerate(batch): # batch is list of Parameter
120
+ p.copy_(stacked_params[i])
121
+
122
+
123
+ class ScaledAdam(BatchedOptimizer):
124
+ """
125
+ Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update
126
+ proportional to the norm of that parameter; and also learn the scale of the parameter,
127
+ in log space, subject to upper and lower limits (as if we had factored each parameter as
128
+ param = underlying_param * log_scale.exp())
129
+
130
+
131
+ Args:
132
+ params: The parameters or param_groups to optimize (like other Optimizer subclasses)
133
+ lr: The learning rate. We will typically use a learning rate schedule that starts
134
+ at 0.03 and decreases over time, i.e. much higher than other common
135
+ optimizers.
136
+ clipping_scale: (e.g. 2.0)
137
+ A scale for gradient-clipping: if specified, the normalized gradients
138
+ over the whole model will be clipped to have 2-norm equal to
139
+ `clipping_scale` times the median 2-norm over the most recent period
140
+ of `clipping_update_period` minibatches. By "normalized gradients",
141
+ we mean after multiplying by the rms parameter value for this tensor
142
+ [for non-scalars]; this is appropriate because our update is scaled
143
+ by this quantity.
144
+ betas: beta1,beta2 are momentum constants for regular momentum, and moving sum-sq grad.
145
+ Must satisfy 0 < beta <= beta2 < 1.
146
+ scalar_lr_scale: A scaling factor on the learning rate, that we use to update the
147
+ scale of each parameter tensor and scalar parameters of the mode..
148
+ If each parameter were decomposed
149
+ as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, scalar_lr_scale
150
+ would be a the scaling factor on the learning rate of p_scale.
151
+ eps: A general-purpose epsilon to prevent division by zero
152
+ param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of
153
+ learning the scale on the parameters (we'll constrain the rms of each non-scalar
154
+ parameter tensor to be >= this value)
155
+ param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of
156
+ learning the scale on the parameters (we'll constrain the rms of each non-scalar
157
+ parameter tensor to be <= this value)
158
+ scalar_max: Maximum absolute value for scalar parameters (applicable if your
159
+ model has any parameters with numel() == 1).
160
+ size_update_period: The periodicity, in steps, with which we update the size (scale)
161
+ of the parameter tensor. This is provided to save a little time
162
+ in the update.
163
+ clipping_update_period: if clipping_scale is specified, this is the period
164
+ """
165
+
166
+ def __init__(
167
+ self,
168
+ params,
169
+ lr=3e-02,
170
+ clipping_scale=None,
171
+ betas=(0.9, 0.98),
172
+ scalar_lr_scale=0.1,
173
+ eps=1.0e-08,
174
+ param_min_rms=1.0e-05,
175
+ param_max_rms=3.0,
176
+ scalar_max=10.0,
177
+ size_update_period=4,
178
+ clipping_update_period=100,
179
+ parameters_names=None,
180
+ show_dominant_parameters=True, ):
181
+
182
+ assert parameters_names is not None, (
183
+ "Please prepare parameters_names,"
184
+ "which is a List[List[str]]. Each List[str] is for a group"
185
+ "and each str is for a parameter")
186
+ defaults = dict(
187
+ lr=lr,
188
+ clipping_scale=clipping_scale,
189
+ betas=betas,
190
+ scalar_lr_scale=scalar_lr_scale,
191
+ eps=eps,
192
+ param_min_rms=param_min_rms,
193
+ param_max_rms=param_max_rms,
194
+ scalar_max=scalar_max,
195
+ size_update_period=size_update_period,
196
+ clipping_update_period=clipping_update_period, )
197
+
198
+ super(ScaledAdam, self).__init__(params, defaults)
199
+ assert len(self.param_groups) == len(parameters_names)
200
+ self.parameters_names = parameters_names
201
+ self.show_dominant_parameters = show_dominant_parameters
202
+
203
+ def __setstate__(self, state):
204
+ super(ScaledAdam, self).__setstate__(state)
205
+
206
+ @torch.no_grad()
207
+ def step(self, closure=None):
208
+ """Performs a single optimization step.
209
+
210
+ Arguments:
211
+ closure (callable, optional): A closure that reevaluates the model
212
+ and returns the loss.
213
+ """
214
+ loss = None
215
+ if closure is not None:
216
+ with torch.enable_grad():
217
+ loss = closure()
218
+
219
+ batch = True
220
+
221
+ for group, group_params_names in zip(self.param_groups,
222
+ self.parameters_names):
223
+
224
+ with self.batched_params(group["params"],
225
+ group_params_names) as batches:
226
+
227
+ # batches is list of pairs (stacked_param, state). stacked_param is like
228
+ # a regular parameter, and will have a .grad, but the 1st dim corresponds to
229
+ # a stacking dim, it is not a real dim.
230
+
231
+ if (len(batches[0][1]) ==
232
+ 0): # if len(first state) == 0: not yet initialized
233
+ clipping_scale = 1
234
+ else:
235
+ clipping_scale = self._get_clipping_scale(group, batches)
236
+
237
+ for p, state, _ in batches:
238
+ # Perform optimization step.
239
+ # grad is not going to be None, we handled that when creating the batches.
240
+ grad = p.grad
241
+ if grad.is_sparse:
242
+ raise RuntimeError(
243
+ "ScaledAdam optimizer does not support sparse gradients"
244
+ )
245
+ # State initialization
246
+ if len(state) == 0:
247
+ self._init_state(group, p, state)
248
+
249
+ self._step_one_batch(group, p, state, clipping_scale)
250
+
251
+ return loss
252
+
253
+ def _init_state(self, group: dict, p: Tensor, state: dict):
254
+ """
255
+ Initializes state dict for parameter 'p'. Assumes that dim 0 of tensor p
256
+ is actually the batch dimension, corresponding to batched-together
257
+ parameters of a given shape.
258
+
259
+
260
+ Args:
261
+ group: Dict to look up configuration values.
262
+ p: The parameter that we are initializing the state for
263
+ state: Dict from string to whatever state we are initializing
264
+ """
265
+ size_update_period = group["size_update_period"]
266
+
267
+ state["step"] = 0
268
+
269
+ kwargs = {"device": p.device, "dtype": p.dtype}
270
+
271
+ # 'delta' implements conventional momentum. There are
272
+ # several different kinds of update going on, so rather than
273
+ # compute "exp_avg" like in Adam, we store and decay a
274
+ # parameter-change "delta", which combines all forms of
275
+ # update. this is equivalent to how it's done in Adam,
276
+ # except for the first few steps.
277
+ state["delta"] = torch.zeros_like(
278
+ p, memory_format=torch.preserve_format)
279
+
280
+ batch_size = p.shape[0]
281
+ numel = p.numel() // batch_size
282
+ numel = p.numel()
283
+
284
+ if numel > 1:
285
+ # "param_rms" just periodically records the scalar root-mean-square value of
286
+ # the parameter tensor.
287
+ # it has a shape like (batch_size, 1, 1, 1, 1)
288
+ param_rms = (
289
+ (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt())
290
+ state["param_rms"] = param_rms
291
+
292
+ state["scale_exp_avg_sq"] = torch.zeros_like(param_rms)
293
+ state["scale_grads"] = torch.zeros(size_update_period,
294
+ *param_rms.shape, **kwargs)
295
+
296
+ # exp_avg_sq is the weighted sum of scaled gradients. as in Adam.
297
+ state["exp_avg_sq"] = torch.zeros_like(
298
+ p, memory_format=torch.preserve_format)
299
+
300
+ def _get_clipping_scale(self,
301
+ group: dict,
302
+ tuples: List[Tuple[Tensor, dict, List[str]]]
303
+ ) -> float:
304
+ """
305
+ Returns a scalar factor <= 1.0 that dictates gradient clipping, i.e. we will scale the gradients
306
+ by this amount before applying the rest of the update.
307
+
308
+ Args:
309
+ group: the parameter group, an item in self.param_groups
310
+ tuples: a list of tuples of (param, state, param_names)
311
+ where param is a batched set of parameters,
312
+ with a .grad (1st dim is batch dim)
313
+ and state is the state-dict where optimization parameters are kept.
314
+ param_names is a List[str] while each str is name for a parameter
315
+ in batched set of parameters "param".
316
+ """
317
+ assert len(tuples) >= 1
318
+ clipping_scale = group["clipping_scale"]
319
+ (first_p, first_state, _) = tuples[0]
320
+ step = first_state["step"]
321
+ if clipping_scale is None or step == 0:
322
+ # no clipping. return early on step == 0 because the other
323
+ # parameters' state won't have been initialized yet.
324
+ return 1.0
325
+ clipping_update_period = group["clipping_update_period"]
326
+
327
+ tot_sumsq = torch.tensor(0.0, device=first_p.device)
328
+ for (p, state, param_names) in tuples:
329
+ grad = p.grad
330
+ if grad.is_sparse:
331
+ raise RuntimeError(
332
+ "ScaledAdam optimizer does not support sparse gradients")
333
+ if p.numel() == p.shape[0]: # a batch of scalars
334
+ tot_sumsq += (grad**2).sum() # sum() to change shape [1] to []
335
+ else:
336
+ tot_sumsq += ((grad * state["param_rms"])**2).sum()
337
+
338
+ tot_norm = tot_sumsq.sqrt()
339
+ if "model_norms" not in first_state:
340
+ first_state["model_norms"] = torch.zeros(
341
+ clipping_update_period, device=p.device)
342
+ first_state["model_norms"][step % clipping_update_period] = tot_norm
343
+
344
+ if step % clipping_update_period == 0:
345
+ # Print some stats.
346
+ # We don't reach here if step == 0 because we would have returned
347
+ # above.
348
+ sorted_norms = first_state["model_norms"].sort()[0].to("cpu")
349
+ quartiles = []
350
+ for n in range(0, 5):
351
+ index = min(
352
+ clipping_update_period - 1,
353
+ (clipping_update_period // 4) * n, )
354
+ quartiles.append(sorted_norms[index].item())
355
+
356
+ median = quartiles[2]
357
+ threshold = clipping_scale * median
358
+ first_state["model_norm_threshold"] = threshold
359
+ percent_clipped = (first_state["num_clipped"] * 100.0 /
360
+ clipping_update_period
361
+ if "num_clipped" in first_state else 0.0)
362
+ first_state["num_clipped"] = 0
363
+ quartiles = " ".join(["%.3e" % x for x in quartiles])
364
+ logging.info(
365
+ f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, "
366
+ f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}"
367
+ )
368
+
369
+ if step < clipping_update_period:
370
+ return 1.0 # We have not yet estimated a norm to clip to.
371
+ else:
372
+ try:
373
+ model_norm_threshold = first_state["model_norm_threshold"]
374
+ except KeyError:
375
+ logging.info(
376
+ "Warning: model_norm_threshold not in state: possibly "
377
+ "you changed config when restarting, adding clipping_scale option?"
378
+ )
379
+ return 1.0
380
+ ans = min(1.0, (model_norm_threshold / (tot_norm + 1.0e-20)).item())
381
+ if ans < 1.0:
382
+ first_state["num_clipped"] += 1
383
+ if ans < 0.1:
384
+ logging.warn(
385
+ f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}"
386
+ )
387
+ if self.show_dominant_parameters:
388
+ assert p.shape[0] == len(param_names)
389
+ self._show_gradient_dominating_parameter(tuples, tot_sumsq)
390
+ return ans
391
+
392
+ def _show_gradient_dominating_parameter(
393
+ self, tuples: List[Tuple[Tensor, dict, List[str]]],
394
+ tot_sumsq: Tensor):
395
+ """
396
+ Show information of parameter wihch dominanting tot_sumsq.
397
+
398
+ Args:
399
+ tuples: a list of tuples of (param, state, param_names)
400
+ where param is a batched set of parameters,
401
+ with a .grad (1st dim is batch dim)
402
+ and state is the state-dict where optimization parameters are kept.
403
+ param_names is a List[str] while each str is name for a parameter
404
+ in batched set of parameters "param".
405
+ tot_sumsq: sumsq of all parameters. Though it's could be calculated
406
+ from tuples, we still pass it to save some time.
407
+ """
408
+ all_sumsq_orig = {}
409
+ for (p, state, batch_param_names) in tuples:
410
+ # p is a stacked batch parameters.
411
+ batch_grad = p.grad
412
+ if p.numel() == p.shape[0]: # a batch of scalars
413
+ batch_sumsq_orig = batch_grad**2
414
+ # Dummpy values used by following `zip` statement.
415
+ batch_rms_orig = torch.ones(p.shape[0])
416
+ else:
417
+ batch_rms_orig = state["param_rms"]
418
+ batch_sumsq_orig = ((batch_grad * batch_rms_orig)**2).sum(
419
+ dim=list(range(1, batch_grad.ndim)))
420
+
421
+ for name, sumsq_orig, rms, grad in zip(batch_param_names,
422
+ batch_sumsq_orig,
423
+ batch_rms_orig, batch_grad):
424
+
425
+ proportion_orig = sumsq_orig / tot_sumsq
426
+ all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad)
427
+
428
+ assert torch.isclose(
429
+ sum([value[0] for value in all_sumsq_orig.values()]).cpu(),
430
+ torch.tensor(1.0), )
431
+ sorted_by_proportion = {
432
+ k: v
433
+ for k, v in sorted(
434
+ all_sumsq_orig.items(),
435
+ key=lambda item: item[1][0],
436
+ reverse=True, )
437
+ }
438
+ dominant_param_name = next(iter(sorted_by_proportion))
439
+ (dominant_proportion, dominant_sumsq, dominant_rms,
440
+ dominant_grad, ) = sorted_by_proportion[dominant_param_name]
441
+ logging.info(f"Parameter Dominanting tot_sumsq {dominant_param_name}"
442
+ f" with proportion {dominant_proportion:.2f},"
443
+ f" where dominant_sumsq=(grad_sumsq*orig_rms_sq)"
444
+ f"={dominant_sumsq:.3e},"
445
+ f" grad_sumsq = {(dominant_grad**2).sum():.3e},"
446
+ f" orig_rms_sq={(dominant_rms**2).item():.3e}")
447
+
448
+ def _step_one_batch(self,
449
+ group: dict,
450
+ p: Tensor,
451
+ state: dict,
452
+ clipping_scale: float):
453
+ """
454
+ Do the step for one parameter, which is actually going to be a batch of
455
+ `real` parameters, with dim 0 as the batch dim.
456
+ Args:
457
+ group: dict to look up configuration values
458
+ p: parameter to update (actually multiple parameters stacked together
459
+ as a batch)
460
+ state: state-dict for p, to look up the optimizer state
461
+ """
462
+ lr = group["lr"]
463
+ size_update_period = group["size_update_period"]
464
+ beta1 = group["betas"][0]
465
+
466
+ grad = p.grad
467
+ if clipping_scale != 1.0:
468
+ grad = grad * clipping_scale
469
+ step = state["step"]
470
+ delta = state["delta"]
471
+
472
+ delta.mul_(beta1)
473
+ batch_size = p.shape[0]
474
+ numel = p.numel() // batch_size
475
+ if numel > 1:
476
+ # Update the size/scale of p, and set param_rms
477
+ scale_grads = state["scale_grads"]
478
+ scale_grads[step % size_update_period] = (p * grad).sum(
479
+ dim=list(range(1, p.ndim)), keepdim=True)
480
+ if step % size_update_period == size_update_period - 1:
481
+ param_rms = state["param_rms"] # shape: (batch_size, 1, 1, ..)
482
+ param_rms.copy_((p**2)
483
+ .mean(dim=list(range(1, p.ndim)), keepdim=True)
484
+ .sqrt())
485
+ if step > 0:
486
+ # self._size_update() learns the overall scale on the
487
+ # parameter, by shrinking or expanding it.
488
+ self._size_update(group, scale_grads, p, state)
489
+
490
+ if numel == 1:
491
+ # For parameters with 1 element we just use regular Adam.
492
+ # Updates delta.
493
+ self._step_scalar(group, p, state)
494
+ else:
495
+ self._step(group, p, state)
496
+
497
+ state["step"] = step + 1
498
+
499
+ def _size_update(self,
500
+ group: dict,
501
+ scale_grads: Tensor,
502
+ p: Tensor,
503
+ state: dict) -> None:
504
+ """
505
+ Called only where p.numel() > 1, this updates the scale of the parameter.
506
+ If we imagine: p = underlying_param * scale.exp(), and we are doing
507
+ gradient descent on underlying param and on scale, this function does the update
508
+ on `scale`.
509
+
510
+ Args:
511
+ group: dict to look up configuration values
512
+ scale_grads: a tensor of shape (size_update_period, batch_size, 1, 1,...) containing
513
+ grads w.r.t. the scales.
514
+ p: The parameter to update
515
+ state: The state-dict of p
516
+ """
517
+
518
+ param_rms = state["param_rms"]
519
+ beta1, beta2 = group["betas"]
520
+ size_lr = group["lr"] * group["scalar_lr_scale"]
521
+ param_min_rms = group["param_min_rms"]
522
+ param_max_rms = group["param_max_rms"]
523
+ eps = group["eps"]
524
+ step = state["step"]
525
+ batch_size = p.shape[0]
526
+
527
+ size_update_period = scale_grads.shape[0]
528
+ # correct beta2 for the size update period: we will have
529
+ # faster decay at this level.
530
+ beta2_corr = beta2**size_update_period
531
+
532
+ scale_exp_avg_sq = state[
533
+ "scale_exp_avg_sq"] # shape: (batch_size, 1, 1, ..)
534
+ scale_exp_avg_sq.mul_(beta2_corr).add_(
535
+ (scale_grads**2).mean(dim=0), # mean over dim `size_update_period`
536
+ alpha=1 - beta2_corr, ) # shape is (batch_size, 1, 1, ...)
537
+
538
+ # The 1st time we reach here is when size_step == 1.
539
+ size_step = (step + 1) // size_update_period
540
+ bias_correction2 = 1 - beta2_corr**size_step
541
+ # we don't bother with bias_correction1; this will help prevent divergence
542
+ # at the start of training.
543
+
544
+ denom = scale_exp_avg_sq.sqrt() + eps
545
+
546
+ scale_step = (-size_lr * (bias_correction2**0.5) *
547
+ scale_grads.sum(dim=0) / denom)
548
+
549
+ is_too_small = param_rms < param_min_rms
550
+ is_too_large = param_rms > param_max_rms
551
+
552
+ # when the param gets too small, just don't shrink it any further.
553
+ scale_step.masked_fill_(is_too_small, 0.0)
554
+ # when it gets too large, stop it from getting any larger.
555
+ scale_step.masked_fill_(is_too_large, -size_lr * size_update_period)
556
+ delta = state["delta"]
557
+ # the factor of (1-beta1) relates to momentum.
558
+ delta.add_(p * scale_step, alpha=(1 - beta1))
559
+
560
+ def _step(self, group: dict, p: Tensor, state: dict):
561
+ """
562
+ This function does the core update of self.step(), in the case where the members of
563
+ the batch have more than 1 element.
564
+
565
+ Args:
566
+ group: A dict which will be used to look up configuration values
567
+ p: The parameter to be updated
568
+ grad: The grad of p
569
+ state: The state-dict corresponding to parameter p
570
+
571
+ This function modifies p.
572
+ """
573
+ grad = p.grad
574
+ lr = group["lr"]
575
+ beta1, beta2 = group["betas"]
576
+ eps = group["eps"]
577
+ param_min_rms = group["param_min_rms"]
578
+ step = state["step"]
579
+
580
+ exp_avg_sq = state["exp_avg_sq"]
581
+ exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=(1 - beta2))
582
+
583
+ this_step = state["step"] - (state["zero_step"]
584
+ if "zero_step" in state else 0)
585
+ bias_correction2 = 1 - beta2**(this_step + 1)
586
+ if bias_correction2 < 0.99:
587
+ # note: not in-place.
588
+ exp_avg_sq = exp_avg_sq * (1.0 / bias_correction2)
589
+
590
+ denom = exp_avg_sq.sqrt()
591
+ denom += eps
592
+ grad = grad / denom
593
+
594
+ alpha = -lr * (1 - beta1) * state["param_rms"].clamp(min=param_min_rms)
595
+
596
+ delta = state["delta"]
597
+ delta.add_(grad * alpha)
598
+ p.add_(delta)
599
+
600
+ def _step_scalar(self, group: dict, p: Tensor, state: dict):
601
+ """
602
+ A simplified form of the core update for scalar tensors, where we cannot get a good
603
+ estimate of the parameter rms.
604
+ """
605
+ beta1, beta2 = group["betas"]
606
+ scalar_max = group["scalar_max"]
607
+ eps = group["eps"]
608
+ lr = group["lr"] * group["scalar_lr_scale"]
609
+ grad = p.grad
610
+
611
+ exp_avg_sq = state["exp_avg_sq"] # shape: (batch_size,)
612
+ exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
613
+
614
+ # bias_correction2 is like in Adam. Don't bother with bias_correction1;
615
+ # slower update at the start will help stability anyway.
616
+ bias_correction2 = 1 - beta2**(state["step"] + 1)
617
+ denom = (exp_avg_sq / bias_correction2).sqrt() + eps
618
+
619
+ delta = state["delta"]
620
+ delta.add_(grad / denom, alpha=-lr * (1 - beta1))
621
+ p.clamp_(min=-scalar_max, max=scalar_max)
622
+ p.add_(delta)
AR/modules/patched_mha_with_cache.py ADDED
@@ -0,0 +1,465 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.nn.functional import *
2
+ from torch.nn.functional import (
3
+ _mha_shape_check,
4
+ _canonical_mask,
5
+ _none_or_dtype,
6
+ _in_projection_packed,
7
+ )
8
+ from torch.nn import functional as F
9
+ import torch
10
+ # Tensor = torch.Tensor
11
+ # from typing import Callable, List, Optional, Tuple, Union
12
+
13
+
14
def multi_head_attention_forward_patched(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    embed_dim_to_check: int,
    num_heads: int,
    in_proj_weight: Optional[Tensor],
    in_proj_bias: Optional[Tensor],
    bias_k: Optional[Tensor],
    bias_v: Optional[Tensor],
    add_zero_attn: bool,
    dropout_p: float,
    out_proj_weight: Tensor,
    out_proj_bias: Optional[Tensor],
    training: bool = True,
    key_padding_mask: Optional[Tensor] = None,
    need_weights: bool = True,
    attn_mask: Optional[Tensor] = None,
    use_separate_proj_weight: bool = False,
    q_proj_weight: Optional[Tensor] = None,
    k_proj_weight: Optional[Tensor] = None,
    v_proj_weight: Optional[Tensor] = None,
    static_k: Optional[Tensor] = None,
    static_v: Optional[Tensor] = None,
    average_attn_weights: bool = True,
    is_causal: bool = False,
    cache=None,
) -> Tuple[Tensor, Optional[Tensor]]:
    r"""Multi-head attention forward pass with an optional key/value cache.

    This mirrors ``torch.nn.functional.multi_head_attention_forward`` and adds
    the ``cache`` argument, which stores the projected keys and values of
    every attention layer between calls.

    Args:
        query, key, value: map a query and a set of key-value pairs to an output.
            See "Attention Is All You Need" for more details.
        embed_dim_to_check: total dimension of the model.
        num_heads: parallel attention heads.
        in_proj_weight, in_proj_bias: input projection weight and bias.
        bias_k, bias_v: bias of the key and value sequences to be added at dim=0.
        add_zero_attn: add a new batch of zeros to the key and
            value sequences at dim=1.
        dropout_p: probability of an element to be zeroed.
        out_proj_weight, out_proj_bias: the output projection weight and bias.
        training: apply dropout if is ``True``.
        key_padding_mask: if provided, specified padding elements in the key will
            be ignored by the attention. This is a binary mask. When the value is True,
            the corresponding value on the attention layer will be filled with -inf.
        need_weights: output attn_output_weights.
            Default: `True`
            Note: `need_weights` defaults to `True`, but should be set to `False`
            for best performance when attention weights are not needed.
            *Setting need_weights to `True`
            leads to a significant performance degradation.*
        attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
            the batches while a 3D mask allows to specify a different mask for the entries of each batch.
        is_causal: If specified, applies a causal mask as attention mask, and ignores
            attn_mask for computing scaled dot product attention.
            Default: ``False``.
            .. warning::
                is_causal provides a hint that the attn_mask is the
                causal mask. Providing incorrect hints can result in
                incorrect execution, including forward and backward
                compatibility.
        use_separate_proj_weight: the function accepts the proj. weights for query, key,
            and value in different forms. If false, in_proj_weight will be used, which is
            a combination of q_proj_weight, k_proj_weight, v_proj_weight.
        q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias.
        static_k, static_v: static key and value used for attention operators.
        average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across heads.
            Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an effect
            when ``need_weights=True``. Default: True
        cache: optional dict holding the per-layer key/value cache, or ``None``
            to disable caching. The keys read here are ``"first_infer"``
            (``1`` means store the projected k/v, anything else means append to
            the stored ones along dim 0), ``"stage"`` (index of the current
            layer slot), ``"all_stage"`` (number of slots; ``"stage"`` is
            advanced modulo this value on every call), and ``"k"`` / ``"v"``
            (indexable containers with one entry per slot). The dict is
            mutated in place.


    Shape:
        Inputs:
        - query: :math:`(L, E)` or :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
          the embedding dimension.
        - key: :math:`(S, E)` or :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
          the embedding dimension.
        - value: :math:`(S, E)` or :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
          the embedding dimension.
        - key_padding_mask: :math:`(S)` or :math:`(N, S)` where N is the batch size, S is the source sequence length.
          If a FloatTensor is provided, it will be directly added to the value.
          If a BoolTensor is provided, the positions with the
          value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
        - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
          3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
          S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
          positions. If a BoolTensor is provided, positions with ``True``
          are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
          is provided, it will be added to the attention weight.
        - static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
          N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
        - static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
          N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.

        Outputs:
        - attn_output: :math:`(L, E)` or :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
          E is the embedding dimension.
        - attn_output_weights: Only returned when ``need_weights=True``. If ``average_attn_weights=True``, returns
          attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or
          :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and
          :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
          head of shape :math:`(num_heads, L, S)` when input is unbatched or :math:`(N, num_heads, L, S)`.
    """
    # A star import does not bring in underscore-prefixed names, so the private
    # helper used by the separate-projection path is imported explicitly here.
    from torch.nn.functional import _in_projection

    tens_ops = (
        query,
        key,
        value,
        in_proj_weight,
        in_proj_bias,
        bias_k,
        bias_v,
        out_proj_weight,
        out_proj_bias,
    )
    if has_torch_function(tens_ops):
        # Dispatch to this patched function (not the stock one), because only
        # the patched signature accepts the ``cache`` keyword.
        return handle_torch_function(
            multi_head_attention_forward_patched,
            tens_ops,
            query,
            key,
            value,
            embed_dim_to_check,
            num_heads,
            in_proj_weight,
            in_proj_bias,
            bias_k,
            bias_v,
            add_zero_attn,
            dropout_p,
            out_proj_weight,
            out_proj_bias,
            training=training,
            key_padding_mask=key_padding_mask,
            need_weights=need_weights,
            attn_mask=attn_mask,
            is_causal=is_causal,
            use_separate_proj_weight=use_separate_proj_weight,
            q_proj_weight=q_proj_weight,
            k_proj_weight=k_proj_weight,
            v_proj_weight=v_proj_weight,
            static_k=static_k,
            static_v=static_v,
            average_attn_weights=average_attn_weights,
            cache=cache,
        )

    is_batched = _mha_shape_check(
        query, key, value, key_padding_mask, attn_mask, num_heads
    )

    # For unbatched input, we unsqueeze at the expected batch-dim to pretend that the input
    # is batched, run the computation and before returning squeeze the
    # batch dimension so that the output doesn't carry this temporary batch dimension.
    if not is_batched:
        query = query.unsqueeze(1)
        key = key.unsqueeze(1)
        value = value.unsqueeze(1)
        if key_padding_mask is not None:
            key_padding_mask = key_padding_mask.unsqueeze(0)

    # set up shape vars
    tgt_len, bsz, embed_dim = query.shape
    src_len, _, _ = key.shape

    key_padding_mask = _canonical_mask(
        mask=key_padding_mask,
        mask_name="key_padding_mask",
        other_type=_none_or_dtype(attn_mask),
        other_name="attn_mask",
        target_type=query.dtype,
    )

    if is_causal and attn_mask is None:
        raise RuntimeError(
            "Need attn_mask if specifying the is_causal hint. "
            "You may use the Transformer module method "
            "`generate_square_subsequent_mask` to create this mask."
        )

    if is_causal and key_padding_mask is None and not need_weights:
        # When we have a kpm or need weights, we need attn_mask.
        # Otherwise, we pass the is_causal hint straight through to SDPA.
        attn_mask = None
    else:
        attn_mask = _canonical_mask(
            mask=attn_mask,
            mask_name="attn_mask",
            other_type=None,
            other_name="",
            target_type=query.dtype,
            check_other=False,
        )

        if key_padding_mask is not None:
            # We have the attn_mask, and use that to merge kpm into it.
            # Turn off use of is_causal hint, as the merged mask is no
            # longer causal.
            is_causal = False

    assert (
        embed_dim == embed_dim_to_check
    ), f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
    if isinstance(embed_dim, torch.Tensor):
        # embed_dim can be a tensor when JIT tracing
        head_dim = embed_dim.div(num_heads, rounding_mode="trunc")
    else:
        head_dim = embed_dim // num_heads
    assert (
        head_dim * num_heads == embed_dim
    ), f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
    if use_separate_proj_weight:
        # allow MHA to have different embedding dimensions when separate projection weights are used
        assert (
            key.shape[:2] == value.shape[:2]
        ), f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}"
    else:
        assert (
            key.shape == value.shape
        ), f"key shape {key.shape} does not match value shape {value.shape}"

    #
    # compute in-projection
    #
    if not use_separate_proj_weight:
        assert (
            in_proj_weight is not None
        ), "use_separate_proj_weight is False but in_proj_weight is None"
        q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias)
    else:
        assert (
            q_proj_weight is not None
        ), "use_separate_proj_weight is True but q_proj_weight is None"
        assert (
            k_proj_weight is not None
        ), "use_separate_proj_weight is True but k_proj_weight is None"
        assert (
            v_proj_weight is not None
        ), "use_separate_proj_weight is True but v_proj_weight is None"
        if in_proj_bias is None:
            b_q = b_k = b_v = None
        else:
            b_q, b_k, b_v = in_proj_bias.chunk(3)
        q, k, v = _in_projection(
            query,
            key,
            value,
            q_proj_weight,
            k_proj_weight,
            v_proj_weight,
            b_q,
            b_k,
            b_v,
        )

    #
    # key/value cache
    #
    if cache is not None:
        stage = cache["stage"]
        if cache["first_infer"] == 1:
            cache["k"][stage] = k
            cache["v"][stage] = v
        else:
            # Every layer keeps its own cached k/v. The projected tensors are
            # time-major here, so new steps are appended along dim 0.
            cache["k"][stage] = torch.cat([cache["k"][stage], k], 0)
            cache["v"][stage] = torch.cat([cache["v"][stage], v], 0)
            src_len = cache["k"][stage].shape[0]
            k = cache["k"][stage]
            v = cache["v"][stage]
        # Advance to the next layer's slot, wrapping around after the last one.
        cache["stage"] = (stage + 1) % cache["all_stage"]

    # prep attention mask (already canonicalised above, or None)
    if attn_mask is not None:
        # ensure attn_mask's dim is 3
        if attn_mask.dim() == 2:
            correct_2d_size = (tgt_len, src_len)
            if attn_mask.shape != correct_2d_size:
                raise RuntimeError(
                    f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}."
                )
            attn_mask = attn_mask.unsqueeze(0)
        elif attn_mask.dim() == 3:
            correct_3d_size = (bsz * num_heads, tgt_len, src_len)
            if attn_mask.shape != correct_3d_size:
                raise RuntimeError(
                    f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}."
                )
        else:
            raise RuntimeError(
                f"attn_mask's dimension {attn_mask.dim()} is not supported"
            )

    # add bias along batch dimension (currently second)
    if bias_k is not None and bias_v is not None:
        assert static_k is None, "bias cannot be added to static key."
        assert static_v is None, "bias cannot be added to static value."
        k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
        v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
        if attn_mask is not None:
            attn_mask = pad(attn_mask, (0, 1))
        if key_padding_mask is not None:
            key_padding_mask = pad(key_padding_mask, (0, 1))
    else:
        assert bias_k is None
        assert bias_v is None

    #
    # reshape q, k, v for multihead attention and make em batch first
    #
    q = q.view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
    if static_k is None:
        k = k.view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
    else:
        # TODO finish disentangling control flow so we don't do in-projections when statics are passed
        assert (
            static_k.size(0) == bsz * num_heads
        ), f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}"
        assert (
            static_k.size(2) == head_dim
        ), f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}"
        k = static_k
    if static_v is None:
        v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
    else:
        # TODO finish disentangling control flow so we don't do in-projections when statics are passed
        assert (
            static_v.size(0) == bsz * num_heads
        ), f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}"
        assert (
            static_v.size(2) == head_dim
        ), f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}"
        v = static_v

    # add zero attention along batch dimension (now first)
    if add_zero_attn:
        zero_attn_shape = (bsz * num_heads, 1, head_dim)
        k = torch.cat(
            [k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1
        )
        v = torch.cat(
            [v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1
        )
        if attn_mask is not None:
            attn_mask = pad(attn_mask, (0, 1))
        if key_padding_mask is not None:
            key_padding_mask = pad(key_padding_mask, (0, 1))

    # update source sequence length after adjustments
    src_len = k.size(1)

    # merge key padding and attention masks
    if key_padding_mask is not None:
        assert key_padding_mask.shape == (
            bsz,
            src_len,
        ), f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}"
        key_padding_mask = (
            key_padding_mask.view(bsz, 1, 1, src_len)
            .expand(-1, num_heads, -1, -1)
            .reshape(bsz * num_heads, 1, src_len)
        )
        if attn_mask is None:
            attn_mask = key_padding_mask
        else:
            attn_mask = attn_mask + key_padding_mask

    # adjust dropout probability
    if not training:
        dropout_p = 0.0

    #
    # (deep breath) calculate attention and out projection
    #

    if need_weights:
        B, Nt, E = q.shape
        q_scaled = q / math.sqrt(E)

        assert not (
            is_causal and attn_mask is None
        ), "FIXME: is_causal not implemented for need_weights"

        if attn_mask is not None:
            attn_output_weights = torch.baddbmm(
                attn_mask, q_scaled, k.transpose(-2, -1)
            )
        else:
            attn_output_weights = torch.bmm(q_scaled, k.transpose(-2, -1))
        attn_output_weights = softmax(attn_output_weights, dim=-1)
        if dropout_p > 0.0:
            attn_output_weights = dropout(attn_output_weights, p=dropout_p)

        attn_output = torch.bmm(attn_output_weights, v)

        attn_output = (
            attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
        )
        attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
        attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))

        # optionally average attention weights over heads
        attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
        if average_attn_weights:
            attn_output_weights = attn_output_weights.mean(dim=1)

        if not is_batched:
            # squeeze the output if input was unbatched
            attn_output = attn_output.squeeze(1)
            attn_output_weights = attn_output_weights.squeeze(0)
        return attn_output, attn_output_weights
    else:
        # attn_mask can be either (L,S) or (N*num_heads, L, S)
        # if attn_mask's shape is (1, L, S) we need to unsqueeze to (1, 1, L, S)
        # in order to match the input for SDPA of (N, num_heads, L, S)
        if attn_mask is not None:
            if attn_mask.size(0) == 1 and attn_mask.dim() == 3:
                attn_mask = attn_mask.unsqueeze(0)
            else:
                attn_mask = attn_mask.view(bsz, num_heads, -1, src_len)

        q = q.view(bsz, num_heads, tgt_len, head_dim)
        k = k.view(bsz, num_heads, src_len, head_dim)
        v = v.view(bsz, num_heads, src_len, head_dim)

        attn_output = scaled_dot_product_attention(
            q, k, v, attn_mask, dropout_p, is_causal
        )

        attn_output = (
            attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)
        )

        attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
        attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
        if not is_batched:
            # squeeze the output if input was unbatched
            attn_output = attn_output.squeeze(1)
        return attn_output, None
AR/modules/patched_mha_with_cache_onnx.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.nn.functional import *
2
+ from torch.nn.functional import (
3
+ _mha_shape_check,
4
+ _canonical_mask,
5
+ _none_or_dtype,
6
+ _in_projection_packed,
7
+ )
8
+
9
def multi_head_attention_forward_patched(
    query,
    key,
    value,
    embed_dim_to_check: int,
    num_heads: int,
    in_proj_weight,
    in_proj_bias: Optional[Tensor],
    bias_k: Optional[Tensor],
    bias_v: Optional[Tensor],
    add_zero_attn: bool,
    dropout_p: float,
    out_proj_weight: Tensor,
    out_proj_bias: Optional[Tensor],
    training: bool = True,
    key_padding_mask: Optional[Tensor] = None,
    need_weights: bool = True,
    attn_mask: Optional[Tensor] = None,
    use_separate_proj_weight: bool = False,
    q_proj_weight: Optional[Tensor] = None,
    k_proj_weight: Optional[Tensor] = None,
    v_proj_weight: Optional[Tensor] = None,
    static_k: Optional[Tensor] = None,
    static_v: Optional[Tensor] = None,
    average_attn_weights: bool = True,
    is_causal: bool = False,
    cache=None,
) -> Tensor:
    """Reduced self-attention forward pass with a key/value cache.

    Only ``query``, ``num_heads``, ``in_proj_weight``, ``in_proj_bias``,
    ``out_proj_weight``, ``out_proj_bias``, ``attn_mask``, ``is_causal`` and
    ``cache`` are read. The remaining parameters are accepted so the signature
    matches the full implementation, and are ignored: q, k and v are all
    projected from ``query``, and dropout is always disabled.

    Args:
        query: 3-D input tensor whose last dimension is the embedding size.
            NOTE(review): the reshapes below fold every leading dimension into
            the sequence axis and the result is viewed as ``(-1, 1, E)``, so
            this appears to assume a batch size of 1 -- confirm with callers.
        num_heads: number of attention heads.
        in_proj_weight, in_proj_bias: packed q/k/v projection parameters.
        out_proj_weight, out_proj_bias: output projection parameters.
        attn_mask: optional mask; it receives two leading singleton dimensions
            before being handed to scaled dot product attention.
        is_causal: forwarded unchanged to scaled dot product attention.
        cache: optional dict with keys ``"first_infer"``, ``"stage"``,
            ``"all_stage"``, ``"k"`` and ``"v"``. When ``"first_infer"`` is 1
            the projected k/v are stored in slot ``"stage"``; otherwise the
            last stored entry along dim 0 is dropped and the new k/v are
            appended in its place. ``"stage"`` is then advanced modulo
            ``"all_stage"``. ``None`` disables caching.

    Returns:
        The attention output, viewed as ``(-1, 1, E)``. Unlike the full
        implementation, no attention weights are returned.
    """
    # set up shape vars
    _, _, embed_dim = query.shape
    head_dim = embed_dim // num_heads

    # Converts a boolean mask to an additive float mask; None passes through.
    attn_mask = _canonical_mask(
        mask=attn_mask,
        mask_name="attn_mask",
        other_type=None,
        other_name="",
        target_type=query.dtype,
        check_other=False,
    )

    # Packed projection: split the last dimension into (3, E) and move the
    # factor of 3 to the front so q, k and v can be indexed directly.
    proj_qkv = linear(query, in_proj_weight, in_proj_bias)
    proj_qkv = (
        proj_qkv.unflatten(-1, (3, query.size(-1)))
        .unsqueeze(0)
        .transpose(0, -2)
        .squeeze(-2)
        .contiguous()
    )
    q, k, v = proj_qkv[0], proj_qkv[1], proj_qkv[2]

    if cache is not None:
        stage = cache["stage"]
        if cache["first_infer"] == 1:
            cache["k"][stage] = k
            cache["v"][stage] = v
        else:
            # Replace the last cached step with the newly projected one.
            cache["k"][stage] = torch.cat([cache["k"][stage][:-1], k], 0)
            cache["v"][stage] = torch.cat([cache["v"][stage][:-1], v], 0)
            k = cache["k"][stage]
            v = cache["v"][stage]
        # Advance to the next layer's slot, wrapping around after the last one.
        cache["stage"] = (stage + 1) % cache["all_stage"]

    if attn_mask is not None:
        # Add two leading singleton dimensions to line up with the 4-D q/k/v.
        attn_mask = attn_mask.unsqueeze(0).unsqueeze(0)

    # (T, E) -> (1, num_heads, T, head_dim)
    q = q.view(-1, num_heads, head_dim).transpose(0, 1).unsqueeze(0)
    k = k.view(-1, num_heads, head_dim).transpose(0, 1).unsqueeze(0)
    v = v.view(-1, num_heads, head_dim).transpose(0, 1).unsqueeze(0)

    # Dropout is never applied in this variant.
    dropout_p = 0.0
    attn_output = scaled_dot_product_attention(
        q, k, v, attn_mask, dropout_p, is_causal
    )
    attn_output = (
        attn_output.permute(2, 0, 1, 3).contiguous().view(-1, embed_dim)
    )
    attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
    attn_output = attn_output.view(-1, 1, attn_output.size(1))

    return attn_output
AR/modules/scaling.py ADDED
@@ -0,0 +1,335 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 Xiaomi Corp. (authors: Daniel Povey)
2
+ #
3
+ # See ../../../../LICENSE for clarification regarding multiple authors
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ import logging
17
+ import math
18
+ import random
19
+ from typing import Optional
20
+ from typing import Tuple
21
+ from typing import Union
22
+
23
+ import torch
24
+ import torch.nn as nn
25
+ from torch import Tensor
26
+
27
+
28
class DoubleSwishFunction(torch.autograd.Function):
    """
      double_swish(x) = x * torch.sigmoid(x-1)
    This is a definition, originally motivated by its close numerical
    similarity to swish(swish(x)), where swish(x) =  x * sigmoid(x).

    Memory-efficient derivative computation:
     double_swish(x) = x * s, where s(x) = torch.sigmoid(x-1)
     double_swish'(x) = d/dx double_swish(x) =  x * s'(x) + x' * s(x) = x * s'(x) + s(x).
     Now, s'(x) = s(x) * (1-s(x)).
     double_swish'(x) =  x * s'(x) + s(x).
                      =  x * s(x) * (1-s(x)) + s(x).
                      = double_swish(x) * (1-s(x)) + s(x)
     ... so we just need to remember s(x) but not x itself.

    The derivative is stored quantised to ``uint8`` (with random rounding),
    so the backward pass is an approximation of the exact derivative.
    """

    @staticmethod
    def forward(ctx, x: Tensor) -> Tensor:
        """Compute ``x * sigmoid(x - 1)`` and, if ``x`` requires grad, save
        the 8-bit quantised derivative for the backward pass.

        Half-precision inputs are computed in float32 and the result is cast
        back to float16; the result is also cast to float16 whenever autocast
        is enabled.
        """
        requires_grad = x.requires_grad
        # Remember the caller's dtype: x itself may be promoted just below.
        x_dtype = x.dtype
        if x_dtype == torch.float16:
            x = x.to(torch.float32)

        s = torch.sigmoid(x - 1.0)
        y = x * s

        if requires_grad:
            deriv = y * (1 - s) + s
            # notes on derivative of x * sigmoid(x - 1):
            # https://www.wolframalpha.com/input?i=d%2Fdx+%28x+*+sigmoid%28x-1%29%29
            # min \simeq -0.043638.  Take floor as -0.043637 so it's a lower bound
            # max \simeq 1.1990.  Take ceil to be 1.2 so it's an upper bound.
            # the combination of "+ torch.rand_like(deriv)" and casting to torch.uint8 (which
            # floors), should be expectation-preserving.
            floor = -0.043637
            ceil = 1.2
            d_scaled = (deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like(
                deriv
            )
            if __name__ == "__main__":
                # for self-testing only.
                assert d_scaled.min() >= 0.0
                assert d_scaled.max() < 256.0
            d_int = d_scaled.to(torch.uint8)
            ctx.save_for_backward(d_int)
        # Test the *original* dtype: by this point x has already been promoted
        # to float32, so checking x.dtype would never detect a half input.
        if x_dtype == torch.float16 or torch.is_autocast_enabled():
            y = y.to(torch.float16)
        return y

    @staticmethod
    def backward(ctx, y_grad: Tensor) -> Tensor:
        """Scale ``y_grad`` by the de-quantised derivative saved in forward."""
        (d,) = ctx.saved_tensors
        # the same constants as used in forward pass.
        floor = -0.043637
        ceil = 1.2
        d = d * ((ceil - floor) / 255.0) + floor
        return y_grad * d
85
+
86
+
87
class DoubleSwish(torch.nn.Module):
    def forward(self, x: Tensor) -> Tensor:
        """Apply the double-swish activation, ``x * sigmoid(x - 1)``, which
        closely approximates Swish(Swish(x)).

        When scripting or tracing, the plain expression is used; otherwise the
        custom autograd function ``DoubleSwishFunction`` is applied.
        """
        if torch.jit.is_scripting() or torch.jit.is_tracing():
            shifted = x - 1.0
            return x * torch.sigmoid(shifted)
        return DoubleSwishFunction.apply(x)
95
+
96
+
97
class ActivationBalancerFunction(torch.autograd.Function):
    """Identity in the forward pass; rescales the gradient in the backward
    pass according to per-channel scale and (optional) sign factors."""

    @staticmethod
    def forward(
        ctx,
        x: Tensor,
        scale_factor: Tensor,
        sign_factor: Optional[Tensor],
        channel_dim: int,
    ) -> Tensor:
        """Return ``x`` unchanged, saving the mask ``x > 0`` and the factors
        for the backward pass. A negative ``channel_dim`` is counted from the
        end of ``x``'s dimensions."""
        ctx.channel_dim = channel_dim + x.ndim if channel_dim < 0 else channel_dim
        is_positive = x > 0
        to_save = (is_positive, scale_factor)
        if sign_factor is not None:
            to_save = to_save + (sign_factor,)
        ctx.save_for_backward(*to_save)
        return x

    @staticmethod
    def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None]:
        """Return ``x_grad - |x_grad| * factor`` for ``x`` and ``None`` for
        the three non-differentiable inputs."""
        saved = ctx.saved_tensors
        is_positive, scale_factor = saved[0], saved[1]
        sign_factor = saved[2] if len(saved) == 3 else None

        # Append singleton dims so the per-channel factors broadcast over
        # every dimension that follows the channel dimension.
        trailing_dims = x_grad.ndim - 1 - ctx.channel_dim
        for _ in range(trailing_dims):
            scale_factor = scale_factor.unsqueeze(-1)
            if sign_factor is not None:
                sign_factor = sign_factor.unsqueeze(-1)

        # +0.5 where x was positive, -0.5 elsewhere.
        half_sign = is_positive.to(x_grad.dtype) - 0.5
        factor = scale_factor * half_sign
        if sign_factor is not None:
            factor = sign_factor + factor

        neg_delta_grad = x_grad.abs() * factor
        return x_grad - neg_delta_grad, None, None, None
136
+
137
+
138
+ def _compute_scale_factor(
139
+ x: Tensor,
140
+ channel_dim: int,
141
+ min_abs: float,
142
+ max_abs: float,
143
+ gain_factor: float,
144
+ max_factor: float,
145
+ ) -> Tensor:
146
+ if channel_dim < 0:
147
+ channel_dim += x.ndim
148
+ sum_dims = [d for d in range(x.ndim) if d != channel_dim]
149
+ x_abs_mean = torch.mean(x.abs(), dim=sum_dims).to(torch.float32)
150
+
151
+ if min_abs == 0.0:
152
+ below_threshold = 0.0
153
+ else:
154
+ # below_threshold is 0 if x_abs_mean > min_abs, can be at most max_factor if
155
+ # x_abs)_mean , min_abs.
156
+ below_threshold = ((min_abs - x_abs_mean) * (gain_factor / min_abs)).clamp(
157
+ min=0, max=max_factor
158
+ )
159
+
160
+ above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp(
161
+ min=0, max=max_factor
162
+ )
163
+
164
+ return below_threshold - above_threshold
165
+
166
+
167
+ def _compute_sign_factor(
168
+ x: Tensor,
169
+ channel_dim: int,
170
+ min_positive: float,
171
+ max_positive: float,
172
+ gain_factor: float,
173
+ max_factor: float,
174
+ ) -> Tensor:
175
+ if channel_dim < 0:
176
+ channel_dim += x.ndim
177
+ sum_dims = [d for d in range(x.ndim) if d != channel_dim]
178
+ proportion_positive = torch.mean((x > 0).to(torch.float32), dim=sum_dims)
179
+ if min_positive == 0.0:
180
+ factor1 = 0.0
181
+ else:
182
+ # 0 if proportion_positive >= min_positive, else can be
183
+ # as large as max_factor.
184
+ factor1 = (
185
+ (min_positive - proportion_positive) * (gain_factor / min_positive)
186
+ ).clamp_(min=0, max=max_factor)
187
+
188
+ if max_positive == 1.0:
189
+ factor2 = 0.0
190
+ else:
191
+ # 0 if self.proportion_positive <= max_positive, else can be
192
+ # as large as -max_factor.
193
+ factor2 = (
194
+ (proportion_positive - max_positive) * (gain_factor / (1.0 - max_positive))
195
+ ).clamp_(min=0, max=max_factor)
196
+ sign_factor = factor1 - factor2
197
+ # require min_positive != 0 or max_positive != 1:
198
+ assert not isinstance(sign_factor, float)
199
+ return sign_factor
200
+
201
+
202
class ActivationBalancer(torch.nn.Module):
    """
    Modifies the backpropped derivatives of a function to try to encourage, for
    each channel, that it is positive at least a proportion `threshold` of the
    time.  It does this by multiplying negative derivative values by up to
    (1+max_factor), and positive derivative values by up to (1-max_factor),
    interpolated from 1 at the threshold to those extremal values when none
    of the inputs are positive.

    Args:
           num_channels: the number of channels
           channel_dim: the dimension/axis corresponding to the channel, e.g.
               -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative.
           min_positive: the minimum, per channel, of the proportion of the time
               that (x > 0), below which we start to modify the derivatives.
           max_positive: the maximum, per channel, of the proportion of the time
               that (x > 0), above which we start to modify the derivatives.
           max_factor: the maximum factor by which we modify the derivatives for
              either the sign constraint or the magnitude constraint;
              e.g. with max_factor=0.02, the derivatives would be multiplied by
              values in the range [0.98..1.02].
           sign_gain_factor: determines the 'gain' with which we increase the
              change in gradient once the constraints on min_positive and max_positive
              are violated.
           scale_gain_factor: determines the 'gain' with which we increase the
              change in gradient once the constraints on min_abs and max_abs
              are violated.
           min_abs:  the minimum average-absolute-value difference from the mean
               value per channel, which we allow, before we start to modify
               the derivatives to prevent this.
           max_abs:  the maximum average-absolute-value difference from the mean
               value per channel, which we allow, before we start to modify
               the derivatives to prevent this.
          min_prob: determines the minimum probability with which we modify the
             gradients for the {min,max}_positive and {min,max}_abs constraints,
             on each forward().  This is done randomly to prevent all layers
             from doing it at the same time.  Early in training we may use
             higher probabilities than this; it will decay to this value.
    """

    def __init__(
        self,
        num_channels: int,
        channel_dim: int,
        min_positive: float = 0.05,
        max_positive: float = 0.95,
        max_factor: float = 0.04,
        sign_gain_factor: float = 0.01,
        scale_gain_factor: float = 0.02,
        min_abs: float = 0.2,
        max_abs: float = 100.0,
        min_prob: float = 0.1,
    ):
        super(ActivationBalancer, self).__init__()
        self.num_channels = num_channels
        self.channel_dim = channel_dim
        self.min_positive = min_positive
        self.max_positive = max_positive
        self.max_factor = max_factor
        self.min_abs = min_abs
        self.max_abs = max_abs
        self.min_prob = min_prob
        self.sign_gain_factor = sign_gain_factor
        self.scale_gain_factor = scale_gain_factor

        # count measures how many times the forward() function has been called.
        # We occasionally sync this to a tensor called `count`, that exists to
        # make sure it is synced to disk when we load and save the model.
        self.cpu_count = 0
        self.register_buffer("count", torch.tensor(0, dtype=torch.int64))

    @staticmethod
    def _no_op(x: Tensor) -> Tensor:
        """Return ``x`` with its values unchanged.

        The module-level ``_no_op`` that forward() used to call is not defined
        in this file, so the pass-through lives here. Outside scripting/tracing
        the result is produced with a single-chunk ``chunk`` call, so it is
        not the very same tensor object as ``x``; this mirrors the helper in
        the upstream icefall implementation.
        """
        if torch.jit.is_scripting() or torch.jit.is_tracing():
            return x
        return x.chunk(1, dim=-1)[0]

    def forward(self, x: Tensor) -> Tensor:
        """Return ``x`` unchanged in value; with some probability, attach the
        gradient-modifying ``ActivationBalancerFunction`` to it.

        Nothing is attached when scripting/tracing or when ``x`` does not
        require grad.
        """
        if torch.jit.is_scripting() or not x.requires_grad or torch.jit.is_tracing():
            return self._no_op(x)

        count = self.cpu_count
        self.cpu_count += 1

        if random.random() < 0.01:
            # Occasionally sync self.cpu_count with self.count.
            # count affects the decay of 'prob'.  don't do this on every iter,
            # because syncing with the GPU is slow.
            self.cpu_count = max(self.cpu_count, self.count.item())
            self.count.fill_(self.cpu_count)

        # the prob of doing some work exponentially decreases from 0.5 till it hits
        # a floor at min_prob (==0.1, by default)
        prob = max(self.min_prob, 0.5 ** (1 + (count / 4000.0)))

        if random.random() < prob:
            if self.min_positive != 0.0 or self.max_positive != 1.0:
                sign_factor = _compute_sign_factor(
                    x,
                    self.channel_dim,
                    self.min_positive,
                    self.max_positive,
                    # Dividing by prob compensates for only acting on a
                    # fraction `prob` of the calls.
                    gain_factor=self.sign_gain_factor / prob,
                    max_factor=self.max_factor,
                )
            else:
                sign_factor = None

            scale_factor = _compute_scale_factor(
                x.detach(),
                self.channel_dim,
                min_abs=self.min_abs,
                max_abs=self.max_abs,
                gain_factor=self.scale_gain_factor / prob,
                max_factor=self.max_factor,
            )
            return ActivationBalancerFunction.apply(
                x,
                scale_factor,
                sign_factor,
                self.channel_dim,
            )
        else:
            return self._no_op(x)
321
+
322
+
323
def BalancedDoubleSwish(
    d_model, channel_dim=-1, max_abs=10.0, min_prob=0.25
) -> nn.Sequential:
    """
    Build the two-stage module ActivationBalancer -> DoubleSwish.

    ``d_model`` is passed to the balancer as its number of channels; the
    remaining arguments are forwarded to it by keyword.
    """
    stages = [
        ActivationBalancer(
            d_model, channel_dim=channel_dim, max_abs=max_abs, min_prob=min_prob
        ),
        DoubleSwish(),
    ]
    return nn.Sequential(*stages)
AR/modules/transformer.py ADDED
@@ -0,0 +1,378 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/transformer.py
2
+ import copy
3
+ import numbers
4
+ from functools import partial
5
+ from typing import Any
6
+ from typing import Callable
7
+ from typing import List
8
+ from typing import Optional
9
+ from typing import Tuple
10
+ from typing import Union
11
+
12
+ import torch
13
+ from AR.modules.activation import MultiheadAttention
14
+ from AR.modules.scaling import BalancedDoubleSwish
15
+ from torch import nn
16
+ from torch import Tensor
17
+ from torch.nn import functional as F
18
+
19
+ _shape_t = Union[int, List[int], torch.Size]
20
+
21
+
22
class LayerNorm(nn.Module):
    """Layer normalization that also accepts ``(input, embedding)`` tuples.

    For a tuple input only the first element is normalized; the second is
    returned untouched next to it.
    """

    __constants__ = ["normalized_shape", "eps", "elementwise_affine"]
    normalized_shape: Tuple[int, ...]
    eps: float
    elementwise_affine: bool

    def __init__(
        self,
        normalized_shape: Union[int, List[int], torch.Size],
        eps: float = 1e-5,
        elementwise_affine: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        super().__init__()
        # A bare integer means "normalize over one trailing dim of that size".
        if isinstance(normalized_shape, numbers.Integral):
            shape = (normalized_shape,)
        else:
            shape = tuple(normalized_shape)
        self.normalized_shape = shape
        self.eps = eps
        self.elementwise_affine = elementwise_affine

        if elementwise_affine:
            self.weight = nn.Parameter(torch.empty(shape, device=device, dtype=dtype))
            self.bias = nn.Parameter(torch.empty(shape, device=device, dtype=dtype))
        else:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)

        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Reset the affine parameters to weight=1, bias=0 (when present)."""
        if not self.elementwise_affine:
            return
        nn.init.ones_(self.weight)
        nn.init.zeros_(self.bias)

    def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
        """Normalize ``input``; a tuple input yields ``(normalized, embedding)``."""
        paired = isinstance(input, tuple)
        if paired:
            input, embedding = input
        else:
            assert embedding is None

        normed = F.layer_norm(
            input, self.normalized_shape, self.weight, self.bias, self.eps
        )
        return (normed, embedding) if paired else normed

    def extra_repr(self) -> str:
        return (
            f"{self.normalized_shape}, eps={self.eps}, "
            f"elementwise_affine={self.elementwise_affine}"
        )
86
+
87
+
88
class IdentityNorm(nn.Module):
    """Drop-in "norm" layer that returns its input unchanged.

    The constructor accepts the same arguments as a norm layer but uses none
    of them.
    """

    def __init__(
        self,
        d_model: int,
        eps: float = 1e-5,
        device=None,
        dtype=None,
    ) -> None:
        super().__init__()

    def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
        """Return ``input`` as is (tensor or ``(input, embedding)`` tuple)."""
        if not isinstance(input, tuple):
            # A separate embedding is only meaningful in the tuple form.
            assert embedding is None
        return input
104
+
105
+
106
class TransformerEncoder(nn.Module):
    r"""Stack of ``num_layers`` deep-copied encoder layers with an optional
    final norm.

    Args:
        encoder_layer: the layer instance to clone (required).
        num_layers: how many clones to stack (required).
        norm: module applied to the output of the last layer (optional).

    Examples::
        >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8)
        >>> transformer_encoder = TransformerEncoder(encoder_layer, num_layers=6)
        >>> src = torch.rand(10, 32, 512)
        >>> out = transformer_encoder(src)
    """
    __constants__ = ["norm"]

    def __init__(self, encoder_layer, num_layers, norm=None):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(
        self,
        src: Tensor,
        mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        return_layer_states: bool = False,
        cache=None,
    ) -> Tensor:
        r"""Run ``src`` through every layer in order.

        Args:
            src: the sequence to the encoder (required).
            mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).
            return_layer_states: when true, also collect ``output[0]`` of
                every layer and return ``(layer_states, output)``.
            cache: forwarded unchanged to every layer.

        Returns:
            The final output, or ``(layer_states, output)`` when
            ``return_layer_states`` is set.
        """
        layer_states = [] if return_layer_states else None
        hidden = src
        for layer in self.layers:
            hidden = layer(
                hidden,
                src_mask=mask,
                src_key_padding_mask=src_key_padding_mask,
                cache=cache,
            )
            if layer_states is not None:
                layer_states.append(hidden[0])

        if self.norm is not None:
            hidden = self.norm(hidden)

        if layer_states is not None:
            return layer_states, hidden
        return hidden
181
+
182
+
183
class TransformerEncoderLayer(nn.Module):
    """Encoder layer: a self-attention block followed by a feed-forward block.

    Each block is wrapped in a residual connection and a norm, applied before
    the block when ``norm_first`` is true and after it otherwise. The input
    may be a tensor or an ``(x, stage_embedding)`` tuple; the tuple form is
    returned as a tuple again.

    Args:
        d_model: model width passed to attention, linears and norms.
        nhead: number of attention heads.
        dim_feedforward: hidden width of the feed-forward block.
        dropout: dropout probability used in attention and both blocks.
        activation: callable, ``"relu"``/``"gelu"``, a ``functools.partial``
            (called with ``d_model``), or ``BalancedDoubleSwish`` itself.
        batch_first: forwarded to ``MultiheadAttention``.
        norm_first: pre-norm when true, post-norm otherwise.
        device, dtype: factory kwargs for the sub-modules.
        linear1_self_attention_cls, linear2_self_attention_cls: linear
            classes forwarded to ``MultiheadAttention``.
        linear1_feedforward_cls, linear2_feedforward_cls: linear classes used
            for the feed-forward block.
        layer_norm_cls: class used to build ``norm1`` and ``norm2``.
        layer_norm_eps: ``eps`` passed to the norm classes.
        adaptive_layer_norm: wrap both norms in ``AdaptiveLayerNorm``.
    """

    __constants__ = ["batch_first", "norm_first"]

    def __init__(
        self,
        d_model: int,
        nhead: int,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
        batch_first: bool = False,
        norm_first: bool = False,
        device=None,
        dtype=None,
        linear1_self_attention_cls: nn.Module = nn.Linear,
        linear2_self_attention_cls: nn.Module = nn.Linear,
        linear1_feedforward_cls: nn.Module = nn.Linear,
        linear2_feedforward_cls: nn.Module = nn.Linear,
        layer_norm_cls: nn.Module = LayerNorm,
        layer_norm_eps: float = 1e-5,
        adaptive_layer_norm=False,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super(TransformerEncoderLayer, self).__init__()
        self.self_attn = MultiheadAttention(
            d_model,
            nhead,
            dropout=dropout,
            batch_first=batch_first,
            linear1_cls=linear1_self_attention_cls,
            linear2_cls=linear2_self_attention_cls,
            **factory_kwargs,
        )

        # Implementation of Feedforward model
        self.linear1 = linear1_feedforward_cls(
            d_model, dim_feedforward, **factory_kwargs
        )
        self.dropout = nn.Dropout(dropout)
        self.linear2 = linear2_feedforward_cls(
            dim_feedforward, d_model, **factory_kwargs
        )

        self.norm_first = norm_first
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        # Legacy string support for activation function.
        if isinstance(activation, str):
            activation = _get_activation_fn(activation)
        elif isinstance(activation, partial):
            activation = activation(d_model)
        elif activation == BalancedDoubleSwish:
            activation = BalancedDoubleSwish(d_model)
        self.activation = activation

        norm1 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs)
        if layer_norm_cls == IdentityNorm:
            # NOTE(review): BalancedBasicNorm is neither defined nor imported
            # in this module, so this branch raises NameError as written.
            # Confirm where it is meant to come from before using IdentityNorm.
            norm2 = BalancedBasicNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        else:
            norm2 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs)

        if adaptive_layer_norm:
            self.norm1 = AdaptiveLayerNorm(d_model, norm1)
            self.norm2 = AdaptiveLayerNorm(d_model, norm2)
        else:
            self.norm1 = norm1
            self.norm2 = norm2

    def __setstate__(self, state):
        super(TransformerEncoderLayer, self).__setstate__(state)
        # State restored without an activation falls back to ReLU.
        if not hasattr(self, "activation"):
            self.activation = F.relu

    def forward(
        self,
        src: Tensor,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        cache=None,
    ) -> Tensor:
        r"""Pass the input through the encoder layer.

        Args:
            src: the sequence to the encoder layer, or an
                ``(x, stage_embedding)`` tuple (required).
            src_mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).
            cache: forwarded to the self-attention module.

        Returns:
            The transformed tensor, or ``(tensor, stage_embedding)`` when
            ``src`` was a tuple.

        Raises:
            AssertionError: if ``src_key_padding_mask`` is neither bool nor
                floating point.
        """
        x, stage_embedding = src, None
        is_src_tuple = False
        if isinstance(src, tuple):
            x, stage_embedding = src
            is_src_tuple = True

        if src_key_padding_mask is not None:
            _skpm_dtype = src_key_padding_mask.dtype
            if _skpm_dtype != torch.bool and not torch.is_floating_point(
                src_key_padding_mask
            ):
                raise AssertionError(
                    "only bool and floating types of key_padding_mask are supported"
                )

        if self.norm_first:
            x = x + self._sa_block(
                self.norm1(x, stage_embedding),
                src_mask,
                src_key_padding_mask,
                cache=cache,
            )
            x = x + self._ff_block(self.norm2(x, stage_embedding))
        else:
            x = self.norm1(
                x + self._sa_block(x, src_mask, src_key_padding_mask, cache=cache),
                stage_embedding,
            )
            x = self.norm2(x + self._ff_block(x), stage_embedding)

        if is_src_tuple:
            return (x, stage_embedding)
        return x

    # self-attention block
    def _sa_block(
        self,
        x: Tensor,
        attn_mask: Optional[Tensor],
        key_padding_mask: Optional[Tensor],
        cache=None,
    ) -> Tensor:
        # Element 0 of the attention result is used as the block output.
        x = self.self_attn(
            x,
            x,
            x,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            need_weights=False,
            cache=cache,
        )[0]
        return self.dropout1(x)

    # feed forward block
    def _ff_block(self, x: Tensor) -> Tensor:
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout2(x)


def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]:
    """Map a legacy activation name (``"relu"``/``"gelu"``) to its function.

    Raises:
        RuntimeError: for any other name.
    """
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    raise RuntimeError(f"activation should be relu/gelu, not {activation}")
347
+
348
+
349
class AdaptiveLayerNorm(nn.Module):
    r"""Adaptive Layer Normalization.

    Wraps ``norm`` and applies a scale and shift that are projected from an
    embedding: ``weight * norm(input) + bias``.
    """

    def __init__(self, d_model, norm) -> None:
        super().__init__()
        # One linear layer yields both scale and shift (2 * d_model outputs).
        self.project_layer = nn.Linear(d_model, 2 * d_model)
        self.norm = norm
        self.d_model = d_model
        self.eps = self.norm.eps

    def forward(self, input: Tensor, embedding: Tensor = None) -> Tensor:
        """Return the modulated norm; a tuple input yields ``(output, embedding)``."""
        paired = isinstance(input, tuple)
        if paired:
            input, embedding = input

        scale, shift = torch.split(
            self.project_layer(embedding),
            split_size_or_sections=self.d_model,
            dim=-1,
        )
        modulated = scale * self.norm(input) + shift
        return (modulated, embedding) if paired else modulated
375
+
376
+
377
+ def _get_clones(module, N):
378
+ return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
AR/modules/transformer_onnx.py ADDED
@@ -0,0 +1,292 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/transformer.py
2
+ import copy
3
+ import numbers
4
+ from functools import partial
5
+ from typing import Any
6
+ from typing import Callable
7
+ from typing import List
8
+ from typing import Optional
9
+ from typing import Tuple
10
+ from typing import Union
11
+
12
+ import torch
13
+ from AR.modules.activation_onnx import MultiheadAttention
14
+ from AR.modules.scaling import BalancedDoubleSwish
15
+ from torch import nn
16
+ from torch import Tensor
17
+ from torch.nn import functional as F
18
+
19
+ _shape_t = Union[int, List[int], torch.Size]
20
+
21
+
22
class LayerNorm(nn.Module):
    """Layer normalization that also accepts ``(input, embedding)`` tuples.

    For a tuple input only the first element is normalized; the second is
    returned untouched next to it.
    """

    __constants__ = ["normalized_shape", "eps", "elementwise_affine"]
    normalized_shape: Tuple[int, ...]
    eps: float
    elementwise_affine: bool

    def __init__(
        self,
        normalized_shape: Union[int, List[int], torch.Size],
        eps: float = 1e-5,
        elementwise_affine: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        super().__init__()
        # A bare integer means "normalize over one trailing dim of that size".
        if isinstance(normalized_shape, numbers.Integral):
            shape = (normalized_shape,)
        else:
            shape = tuple(normalized_shape)
        self.normalized_shape = shape
        self.eps = eps
        self.elementwise_affine = elementwise_affine

        if elementwise_affine:
            self.weight = nn.Parameter(torch.empty(shape, device=device, dtype=dtype))
            self.bias = nn.Parameter(torch.empty(shape, device=device, dtype=dtype))
        else:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)

        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Reset the affine parameters to weight=1, bias=0 (when present)."""
        if not self.elementwise_affine:
            return
        nn.init.ones_(self.weight)
        nn.init.zeros_(self.bias)

    def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
        """Normalize ``input``; a tuple input yields ``(normalized, embedding)``."""
        paired = isinstance(input, tuple)
        if paired:
            input, embedding = input
        else:
            assert embedding is None

        normed = F.layer_norm(
            input, self.normalized_shape, self.weight, self.bias, self.eps
        )
        return (normed, embedding) if paired else normed

    def extra_repr(self) -> str:
        return (
            f"{self.normalized_shape}, eps={self.eps}, "
            f"elementwise_affine={self.elementwise_affine}"
        )
86
+
87
+
88
class IdentityNorm(nn.Module):
    """Drop-in "norm" layer that returns its input unchanged.

    The constructor accepts the same arguments as a norm layer but uses none
    of them.
    """

    def __init__(
        self,
        d_model: int,
        eps: float = 1e-5,
        device=None,
        dtype=None,
    ) -> None:
        super().__init__()

    def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
        """Return ``input`` as is (tensor or ``(input, embedding)`` tuple)."""
        if not isinstance(input, tuple):
            # A separate embedding is only meaningful in the tuple form.
            assert embedding is None
        return input
104
+
105
+
106
class TransformerEncoder(nn.Module):
    r"""Stack of ``num_layers`` deep-copied encoder layers with an optional
    final norm (ONNX variant).

    Args:
        encoder_layer: the layer instance to clone (required).
        num_layers: how many clones to stack (required).
        norm: module applied to the output of the last layer (optional).

    Examples::
        >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8)
        >>> transformer_encoder = TransformerEncoder(encoder_layer, num_layers=6)
        >>> src = torch.rand(10, 32, 512)
        >>> out = transformer_encoder(src)
    """
    __constants__ = ["norm"]

    def __init__(self, encoder_layer, num_layers, norm=None):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(
        self,
        src: Tensor,
        mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        return_layer_states: bool = False,
        cache=None,
    ) -> Tensor:
        """Run ``src`` through every layer in order, then the optional norm.

        ``return_layer_states`` is accepted but not used by this variant.
        """
        hidden = src
        for layer in self.layers:
            hidden = layer(
                hidden,
                src_mask=mask,
                src_key_padding_mask=src_key_padding_mask,
                cache=cache,
            )

        if self.norm is None:
            return hidden
        return self.norm(hidden)
153
+
154
+
155
class TransformerEncoderLayer(nn.Module):
    """Post-norm encoder layer (ONNX variant): self-attention, then
    feed-forward, each wrapped in a residual connection and a norm.

    ``norm_first`` is stored but ``forward`` always applies the norms after
    the residual sums.

    Args:
        d_model: model width passed to attention, linears and norms.
        nhead: number of attention heads.
        dim_feedforward: hidden width of the feed-forward block.
        dropout: dropout probability used in attention and both blocks.
        activation: callable, ``"relu"``/``"gelu"``, a ``functools.partial``
            (called with ``d_model``), or ``BalancedDoubleSwish`` itself.
        batch_first: forwarded to ``MultiheadAttention``.
        norm_first: stored on the module only.
        device, dtype: factory kwargs for the sub-modules.
        linear1_self_attention_cls, linear2_self_attention_cls: linear
            classes forwarded to ``MultiheadAttention``.
        linear1_feedforward_cls, linear2_feedforward_cls: linear classes used
            for the feed-forward block.
        layer_norm_cls: class used to build ``norm1`` and ``norm2``.
        layer_norm_eps: ``eps`` passed to the norm classes.
        adaptive_layer_norm: wrap both norms in ``AdaptiveLayerNorm``.
    """

    __constants__ = ["batch_first", "norm_first"]

    def __init__(
        self,
        d_model: int,
        nhead: int,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
        batch_first: bool = False,
        norm_first: bool = False,
        device=None,
        dtype=None,
        linear1_self_attention_cls: nn.Module = nn.Linear,
        linear2_self_attention_cls: nn.Module = nn.Linear,
        linear1_feedforward_cls: nn.Module = nn.Linear,
        linear2_feedforward_cls: nn.Module = nn.Linear,
        layer_norm_cls: nn.Module = LayerNorm,
        layer_norm_eps: float = 1e-5,
        adaptive_layer_norm=False,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super(TransformerEncoderLayer, self).__init__()
        self.self_attn = MultiheadAttention(
            d_model,
            nhead,
            dropout=dropout,
            batch_first=batch_first,
            linear1_cls=linear1_self_attention_cls,
            linear2_cls=linear2_self_attention_cls,
            **factory_kwargs,
        )
        self.linear1 = linear1_feedforward_cls(
            d_model, dim_feedforward, **factory_kwargs
        )
        self.dropout = nn.Dropout(dropout)
        self.linear2 = linear2_feedforward_cls(
            dim_feedforward, d_model, **factory_kwargs
        )
        self.norm_first = norm_first
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        # Legacy string support for activation function.
        if isinstance(activation, str):
            activation = _get_activation_fn(activation)
        elif isinstance(activation, partial):
            activation = activation(d_model)
        elif activation == BalancedDoubleSwish:
            activation = BalancedDoubleSwish(d_model)
        self.activation = activation

        norm1 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs)
        if layer_norm_cls == IdentityNorm:
            # NOTE(review): BalancedBasicNorm is neither defined nor imported
            # in this module, so this branch raises NameError as written.
            # Confirm where it is meant to come from before using IdentityNorm.
            norm2 = BalancedBasicNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        else:
            norm2 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs)

        if adaptive_layer_norm:
            self.norm1 = AdaptiveLayerNorm(d_model, norm1)
            self.norm2 = AdaptiveLayerNorm(d_model, norm2)
        else:
            self.norm1 = norm1
            self.norm2 = norm2

    def __setstate__(self, state):
        super(TransformerEncoderLayer, self).__setstate__(state)
        # State restored without an activation falls back to ReLU.
        if not hasattr(self, "activation"):
            self.activation = F.relu

    def forward(
        self,
        src: Tensor,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        cache=None,
    ) -> Tensor:
        """Apply self-attention and feed-forward blocks in post-norm order.

        Args:
            src: the sequence to the encoder layer.
            src_mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).
            cache: forwarded to the self-attention module.
        """
        x = src
        stage_embedding = None
        x = self.norm1(
            x + self._sa_block(x, src_mask, src_key_padding_mask, cache=cache),
            stage_embedding,
        )
        x = self.norm2(x + self._ff_block(x), stage_embedding)

        return x

    def _sa_block(
        self,
        x: Tensor,
        attn_mask: Optional[Tensor],
        key_padding_mask: Optional[Tensor],
        cache=None,
    ) -> Tensor:
        # Unlike the non-ONNX layer, the attention result is used directly
        # (no ``[0]`` indexing).
        x = self.self_attn(
            x,
            x,
            x,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            need_weights=False,
            cache=cache,
        )
        return self.dropout1(x)

    def _ff_block(self, x: Tensor) -> Tensor:
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout2(x)


def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]:
    """Map a legacy activation name (``"relu"``/``"gelu"``) to its function.

    Raises:
        RuntimeError: for any other name.
    """
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    raise RuntimeError(f"activation should be relu/gelu, not {activation}")
261
+
262
+
263
class AdaptiveLayerNorm(nn.Module):
    r"""Adaptive Layer Normalization.

    Wraps ``norm`` and applies a scale and shift that are projected from an
    embedding: ``weight * norm(input) + bias``.
    """

    def __init__(self, d_model, norm) -> None:
        super().__init__()
        # One linear layer yields both scale and shift (2 * d_model outputs).
        self.project_layer = nn.Linear(d_model, 2 * d_model)
        self.norm = norm
        self.d_model = d_model
        self.eps = self.norm.eps

    def forward(self, input: Tensor, embedding: Tensor = None) -> Tensor:
        """Return the modulated norm; a tuple input yields ``(output, embedding)``."""
        paired = isinstance(input, tuple)
        if paired:
            input, embedding = input

        scale, shift = torch.split(
            self.project_layer(embedding),
            split_size_or_sections=self.d_model,
            dim=-1,
        )
        modulated = scale * self.norm(input) + shift
        return (modulated, embedding) if paired else modulated
289
+
290
+
291
+ def _get_clones(module, N):
292
+ return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
AR/text_processing/__init__.py ADDED
File without changes
AR/text_processing/phonemizer.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/text_processing/phonemizer.py
2
+ # reference: https://github.com/lifeiteng/vall-e
3
+ import itertools
4
+ import re
5
+ from typing import Dict
6
+ from typing import List
7
+
8
+ import regex
9
+ from gruut import sentences
10
+ from gruut.const import Sentence
11
+ from gruut.const import Word
12
+ from AR.text_processing.symbols import SYMBOL_TO_ID
13
+
14
+
15
class GruutPhonemizer:
    """Convert text to phoneme strings with gruut and map symbols to ids."""

    def __init__(self, language: str):
        """Store the gruut entry point, the target language and the symbol table.

        Args:
            language: language code passed to gruut as ``lang`` by ``phonemize``.
        """
        self._phonemizer = sentences
        self.lang = language
        self.symbol_to_id = SYMBOL_TO_ID
        # Keys are regex fragments for punctuation marks; together they form
        # the character class used by ``_normalize_punctuation``.
        self._special_cases_dict: Dict[str, str] = {
            r"\.\.\.": "... ",
            ";": "; ",
            ":": ": ",
            ",": ", ",
            r"\.": ". ",
            "!": "! ",
            r"\?": "? ",
            "—": "—",
            "…": "… ",
            "«": "«",
            "»": "»",
        }
        self._punctuation_regexp: str = (
            rf"([{''.join(self._special_cases_dict.keys())}])"
        )

    def _normalize_punctuation(self, text: str) -> str:
        """Normalize the spacing around punctuation marks.

        Removes separators before a punctuation mark, inserts one space
        between a punctuation mark and a following letter, collapses runs of
        separators to a single space, and strips the ends.
        """
        text = regex.sub(rf"\pZ+{self._punctuation_regexp}", r"\1", text)
        text = regex.sub(rf"{self._punctuation_regexp}(\pL)", r"\1 \2", text)
        text = regex.sub(r"\pZ+", r" ", text)
        return text.strip()

    def _convert_punctuation(self, word: "Word") -> str:
        """Return the phoneme string for one gruut word.

        Words without phonemes give ``""``; words whose first phoneme is a
        break marker (``‖`` or ``|``) give their stripped text; otherwise the
        phonemes are joined with the modifier characters removed.
        """
        if not word.phonemes:
            return ""
        if word.phonemes[0] in ["‖", "|"]:
            return word.text.strip()

        phonemes = "".join(word.phonemes)
        # remove modifier characters ˈˌː with regex
        phonemes = re.sub(r"[ˈˌː͡]", "", phonemes)
        return phonemes.strip()

    def phonemize(self, text: str, espeak: bool = False) -> str:
        """Phonemize ``text`` and return the per-word results joined by spaces.

        Args:
            text: input text.
            espeak: forwarded to gruut's ``sentences``.
        """
        text_to_phonemize: str = self._normalize_punctuation(text)
        # Use the language given at construction instead of a fixed "en-us".
        sents: List["Sentence"] = list(
            self._phonemizer(text_to_phonemize, lang=self.lang, espeak=espeak)
        )
        words: List[str] = [
            self._convert_punctuation(word) for word in itertools.chain(*sents)
        ]
        return " ".join(words)

    def transform(self, phonemes):
        """Map each symbol in ``phonemes`` to its id, skipping unknown symbols.

        The mapping is the one in symbols.py (``SYMBOL_TO_ID``).
        """
        return [self.symbol_to_id[p] for p in phonemes if p in self.symbol_to_id]
69
+
70
+
71
if __name__ == "__main__":
    # Small demo: text -> IPA phonemes -> symbol ids.
    phonemizer = GruutPhonemizer("en-us")
    phonemes = phonemizer.phonemize("Hello, wor-ld ?")
    print(f"phonemes: {phonemes}")
    print(f"len(phonemes): {len(phonemes)}")
    phoneme_ids = phonemizer.transform(phonemes)
    print(f"phoneme_ids: {phoneme_ids}")
    print(f"len(phoneme_ids): {len(phoneme_ids)}")
AR/text_processing/symbols.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ # modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/text_processing/symbols.py
2
+ # reference: https://github.com/lifeiteng/vall-e
3
# Symbol inventory: padding symbol, punctuation, ASCII letters, IPA letters.
PAD = "_"
PUNCTUATION = ';:,.!?¡¿—…"«»“” '
LETTERS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
IPA_LETTERS = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"

# PAD comes first, so it gets id 0.
SYMBOLS = [PAD, *PUNCTUATION, *LETTERS, *IPA_LETTERS]
SPACE_ID = SYMBOLS.index(" ")

# If a symbol occurs more than once in SYMBOLS, its last position wins here.
SYMBOL_TO_ID = {symbol: index for index, symbol in enumerate(SYMBOLS)}
ID_TO_SYMBOL = dict(enumerate(SYMBOLS))
AR/utils/__init__.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+
3
+
4
def str2bool(str):
    """Return True only when ``str`` equals "true", ignoring case."""
    return str.lower() == 'true'
6
+
7
+
8
def get_newest_ckpt(string_list):
    """Return the checkpoint name with the highest ``(epoch, step)``.

    Only names starting with ``epoch=<int>-step=<int>.ckpt`` are considered.
    Raises ``IndexError`` when no name matches.
    """
    # Pattern that captures the epoch and step numbers of a checkpoint name.
    pattern = r'epoch=(\d+)-step=(\d+)\.ckpt'

    ranked = []
    for name in string_list:
        found = re.match(pattern, name)
        if found is None:
            continue
        ranked.append((int(found.group(1)), int(found.group(2)), name))

    # Newest first: sort by epoch, then by step, descending.
    ranked.sort(key=lambda item: (item[0], item[1]), reverse=True)
    return ranked[0][2]
26
+
27
+
28
def check_txt_file(file_path):
    """Return the stripped first line of a text file, or False.

    Args:
        file_path: path of the file to probe.

    Returns:
        The first line with surrounding whitespace removed when the file can
        be read and that line is non-empty; ``False`` otherwise.
    """
    try:
        with open(file_path, 'r') as file:
            text = file.readline().strip()
    except Exception:
        # Best-effort probe: any failure to open or read means "no usable text".
        return False
    # Explicit check rather than ``assert``, which is stripped under ``-O``.
    return text if text else False
AR/utils/initialize.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Initialize modules for espnet2 neural networks."""
3
+ import torch
4
+ from typeguard import check_argument_types
5
+
6
+
7
def initialize(model: torch.nn.Module, init: str):
    """Initialize the weights of ``model`` in place.

    Every parameter with more than one dimension is filled using the method
    named by ``init``; afterwards every one-dimensional parameter whose name
    contains ``".bias"`` is zeroed.

    Args:
        model: Target.
        init: One of ``"xavier_uniform"``, ``"xavier_normal"``,
            ``"kaiming_uniform"`` or ``"kaiming_normal"``.

    Raises:
        ValueError: if ``init`` is unknown and a parameter with more than
            one dimension is encountered.
    """
    assert check_argument_types()
    print("init with", init)

    fillers = {
        "xavier_uniform": torch.nn.init.xavier_uniform_,
        "xavier_normal": torch.nn.init.xavier_normal_,
        "kaiming_uniform": lambda t: torch.nn.init.kaiming_uniform_(
            t, nonlinearity="relu"
        ),
        "kaiming_normal": lambda t: torch.nn.init.kaiming_normal_(
            t, nonlinearity="relu"
        ),
    }

    # weight init
    for param in model.parameters():
        if param.dim() <= 1:
            continue
        if init not in fillers:
            raise ValueError("Unknown initialization: " + init)
        fillers[init](param.data)

    # bias init
    for name, param in model.named_parameters():
        if ".bias" in name and param.dim() == 1:
            param.data.zero_()
AR/utils/io.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+
3
+ import torch
4
+ import yaml
5
+
6
+
7
def load_yaml_config(path):
    """Parse the YAML file at ``path`` with ``yaml.full_load`` and return it."""
    with open(path) as stream:
        return yaml.full_load(stream)
11
+
12
+
13
def save_config_to_yaml(config, path):
    """Dump ``config`` as YAML into ``path``, which must end in ``.yaml``."""
    assert path.endswith(".yaml")
    # The ``with`` block closes the file; no explicit close() is needed.
    with open(path, "w") as stream:
        stream.write(yaml.dump(config))
18
+
19
+
20
def write_args(args, path):
    """Append run information and the public attributes of ``args`` to ``path``.

    Writes the torch and cuDNN versions, ``sys.argv``, and one ``name: value``
    line for each attribute of ``args`` whose name does not start with ``_``,
    sorted by name.
    """
    public = {
        name: getattr(args, name) for name in dir(args) if not name.startswith("_")
    }
    with open(path, "a") as args_file:
        args_file.write(f"==> torch version: {torch.__version__}\n")
        args_file.write(f"==> cudnn version: {torch.backends.cudnn.version()}\n")
        args_file.write("==> Cmd:\n")
        args_file.write(str(sys.argv))
        args_file.write("\n==> args:\n")
        for key, value in sorted(public.items()):
            args_file.write(" %s: %s\n" % (str(key), str(value)))
README.md ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: GPT SoVITS V2
3
+ emoji: 🤗
4
+ colorFrom: indigo
5
+ colorTo: red
6
+ sdk: gradio
7
+ sdk_version: 4.24.0
8
+ app_file: inference_webui.py
9
+ pinned: false
10
+ license: mit
11
+ python_version: 3.10.13
12
+ ---
13
+
14
+ GPT-SoVITS-v2 Zero-shot TTS demo
15
+
16
+ Provide 3–10 s of reference audio to guide the timbre, speed, and emotion of the voice, then enter the inference text to generate the speech you want.
configs/s1.yaml ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ train:
2
+ seed: 1234
3
+ epochs: 300
4
+ batch_size: 8
5
+ gradient_accumulation: 4
6
+ save_every_n_epoch: 1
7
+ precision: 16
8
+ gradient_clip: 1.0
9
+ optimizer:
10
+ lr: 0.01
11
+ lr_init: 0.00001
12
+ lr_end: 0.0001
13
+ warmup_steps: 2000
14
+ decay_steps: 40000
15
+ data:
16
+ max_eval_sample: 8
17
+ max_sec: 54
18
+ num_workers: 1
19
+ pad_val: 1024 # same with EOS in model
20
+ model:
21
+ vocab_size: 1025
22
+ phoneme_vocab_size: 512
23
+ embedding_dim: 512
24
+ hidden_dim: 512
25
+ head: 16
26
+ linear_units: 2048
27
+ n_layer: 12
28
+ dropout: 0
29
+ EOS: 1024
30
+ inference:
31
+ top_k: 5
configs/s1big.yaml ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ train:
2
+ seed: 1234
3
+ epochs: 300
4
+ batch_size: 8
5
+ gradient_accumulation: 4
6
+ save_every_n_epoch: 1
7
+ precision: 16-mixed
8
+ gradient_clip: 1.0
9
+ optimizer:
10
+ lr: 0.01
11
+ lr_init: 0.00001
12
+ lr_end: 0.0001
13
+ warmup_steps: 2000
14
+ decay_steps: 40000
15
+ data:
16
+ max_eval_sample: 8
17
+ max_sec: 54
18
+ num_workers: 1
19
+ pad_val: 1024 # same with EOS in model
20
+ model:
21
+ vocab_size: 1025
22
+ phoneme_vocab_size: 512
23
+ embedding_dim: 1024
24
+ hidden_dim: 1024
25
+ head: 16
26
+ linear_units: 2048
27
+ n_layer: 16
28
+ dropout: 0
29
+ EOS: 1024
30
+ inference:
31
+ top_k: 5
configs/s1big2.yaml ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ train:
2
+ seed: 1234
3
+ epochs: 300
4
+ batch_size: 12
5
+ gradient_accumulation: 4
6
+ save_every_n_epoch: 1
7
+ precision: 16-mixed
8
+ gradient_clip: 1.0
9
+ optimizer:
10
+ lr: 0.01
11
+ lr_init: 0.00001
12
+ lr_end: 0.0001
13
+ warmup_steps: 2000
14
+ decay_steps: 40000
15
+ data:
16
+ max_eval_sample: 8
17
+ max_sec: 54
18
+ num_workers: 1
19
+ pad_val: 1024 # same with EOS in model
20
+ model:
21
+ vocab_size: 1025
22
+ phoneme_vocab_size: 512
23
+ embedding_dim: 1024
24
+ hidden_dim: 1024
25
+ head: 16
26
+ linear_units: 2048
27
+ n_layer: 6
28
+ dropout: 0
29
+ EOS: 1024
30
+ inference:
31
+ top_k: 5
configs/s1longer-v2.yaml ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ train:
2
+ seed: 1234
3
+ epochs: 20
4
+ batch_size: 8
5
+ save_every_n_epoch: 1
6
+ precision: 16-mixed
7
+ gradient_clip: 1.0
8
+ optimizer:
9
+ lr: 0.01
10
+ lr_init: 0.00001
11
+ lr_end: 0.0001
12
+ warmup_steps: 2000
13
+ decay_steps: 40000
14
+ data:
15
+ max_eval_sample: 8
16
+ max_sec: 54
17
+ num_workers: 4
18
+ pad_val: 1024 # same as EOS in model
19
+ model:
20
+ vocab_size: 1025
21
+ phoneme_vocab_size: 732
22
+ embedding_dim: 512
23
+ hidden_dim: 512
24
+ head: 16
25
+ linear_units: 2048
26
+ n_layer: 24
27
+ dropout: 0
28
+ EOS: 1024
29
+ random_bert: 0
30
+ inference:
31
+ top_k: 15
configs/s1longer.yaml ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ train:
2
+ seed: 1234
3
+ epochs: 20
4
+ batch_size: 8
5
+ save_every_n_epoch: 1
6
+ precision: 16-mixed
7
+ gradient_clip: 1.0
8
+ optimizer:
9
+ lr: 0.01
10
+ lr_init: 0.00001
11
+ lr_end: 0.0001
12
+ warmup_steps: 2000
13
+ decay_steps: 40000
14
+ data:
15
+ max_eval_sample: 8
16
+ max_sec: 54
17
+ num_workers: 4
18
+ pad_val: 1024 # same as EOS in model
19
+ model:
20
+ vocab_size: 1025
21
+ phoneme_vocab_size: 512
22
+ embedding_dim: 512
23
+ hidden_dim: 512
24
+ head: 16
25
+ linear_units: 2048
26
+ n_layer: 24
27
+ dropout: 0
28
+ EOS: 1024
29
+ random_bert: 0
30
+ inference:
31
+ top_k: 5
configs/s1mq.yaml ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ train:
2
+ seed: 1234
3
+ epochs: 100
4
+ batch_size: 6
5
+ gradient_accumulation: 4
6
+ save_every_n_epoch: 1
7
+ precision: 32
8
+ gradient_clip: 1.0
9
+ optimizer:
10
+ lr: 0.01
11
+ lr_init: 0.00001
12
+ lr_end: 0.0001
13
+ warmup_steps: 2000
14
+ decay_steps: 40000
15
+ data:
16
+ max_eval_sample: 8
17
+ max_sec: 40
18
+ num_workers: 1
19
+ pad_val: 1024 # same as EOS in model
20
+ model:
21
+ saving_path: "ckpt/"
22
+ resume_checkpoint: null
23
+ vocoder_config_path: "quantizer/new_ckpt/config.json"
24
+ vocoder_ckpt_path: "quantizer/new_ckpt/g_00600000"
25
+ datadir: "/home/liweiche/GigaSpeech/wavs"
26
+ metapath: "/home/liweiche/GigaSpeech/train2.json"
27
+ val_metapath: "/home/liweiche/GigaSpeech/dev2.json"
28
+ sampledir: "logs/"
29
+ pretrained_path: null
30
+ lr: 0.0001
31
+ batch_size: 200.0
32
+ train_bucket_size: 8192
33
+ training_step: 800000
34
+ optim_flat_percent: 0.0
35
+ warmup_step: 50
36
+ adam_beta1: 0.9
37
+ adam_beta2: 0.98
38
+ ffd_size: 3072
39
+ hidden_size: 768
40
+ enc_nlayers: 6
41
+ dec_nlayers: 6
42
+ nheads: 12
43
+ ar_layer: 4
44
+ ar_ffd_size: 1024
45
+ ar_hidden_size: 256
46
+ ar_nheads: 4
47
+ aligner_softmax_temp: 1.0
48
+ layer_norm_eps: 0.00001
49
+ speaker_embed_dropout: 0.05
50
+ label_smoothing: 0.0
51
+ val_check_interval: 5000
52
+ check_val_every_n_epoch: 1
53
+ precision: "fp16"
54
+ nworkers: 16
55
+ distributed: true
56
+ accelerator: "ddp"
57
+ version: null
58
+ accumulate_grad_batches: 1
59
+ use_repetition_token: true
60
+ use_repetition_gating: false
61
+ repetition_penalty: 1.0
62
+ sampling_temperature: 1.0
63
+ top_k: -1
64
+ min_top_k: 3
65
+ top_p: 0.8
66
+ sample_num: 4
67
+ length_penalty_max_length: 15000
68
+ length_penalty_max_prob: 0.95
69
+ max_input_length: 2048
70
+ max_output_length: 2000
71
+ sample_rate: 16000
72
+ n_codes: 1024
73
+ n_cluster_groups: 1
74
+ phone_context_window: 4
75
+ phoneset_size: 1000
76
+ inference:
77
+ top_k: 5
configs/s2.json ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "train": {
3
+ "log_interval": 100,
4
+ "eval_interval": 500,
5
+ "seed": 1234,
6
+ "epochs": 100,
7
+ "learning_rate": 0.0001,
8
+ "betas": [
9
+ 0.8,
10
+ 0.99
11
+ ],
12
+ "eps": 1e-09,
13
+ "batch_size": 32,
14
+ "fp16_run": true,
15
+ "lr_decay": 0.999875,
16
+ "segment_size": 20480,
17
+ "init_lr_ratio": 1,
18
+ "warmup_epochs": 0,
19
+ "c_mel": 45,
20
+ "c_kl": 1.0,
21
+ "text_low_lr_rate": 0.4
22
+ },
23
+ "data": {
24
+ "max_wav_value": 32768.0,
25
+ "sampling_rate": 32000,
26
+ "filter_length": 2048,
27
+ "hop_length": 640,
28
+ "win_length": 2048,
29
+ "n_mel_channels": 128,
30
+ "mel_fmin": 0.0,
31
+ "mel_fmax": null,
32
+ "add_blank": true,
33
+ "n_speakers": 300,
34
+ "cleaned_text": true
35
+ },
36
+ "model": {
37
+ "inter_channels": 192,
38
+ "hidden_channels": 192,
39
+ "filter_channels": 768,
40
+ "n_heads": 2,
41
+ "n_layers": 6,
42
+ "kernel_size": 3,
43
+ "p_dropout": 0.1,
44
+ "resblock": "1",
45
+ "resblock_kernel_sizes": [
46
+ 3,
47
+ 7,
48
+ 11
49
+ ],
50
+ "resblock_dilation_sizes": [
51
+ [
52
+ 1,
53
+ 3,
54
+ 5
55
+ ],
56
+ [
57
+ 1,
58
+ 3,
59
+ 5
60
+ ],
61
+ [
62
+ 1,
63
+ 3,
64
+ 5
65
+ ]
66
+ ],
67
+ "upsample_rates": [
68
+ 10,
69
+ 8,
70
+ 2,
71
+ 2,
72
+ 2
73
+ ],
74
+ "upsample_initial_channel": 512,
75
+ "upsample_kernel_sizes": [
76
+ 16,
77
+ 16,
78
+ 8,
79
+ 2,
80
+ 2
81
+ ],
82
+ "n_layers_q": 3,
83
+ "use_spectral_norm": false,
84
+ "gin_channels": 512,
85
+ "semantic_frame_rate": "25hz",
86
+ "freeze_quantizer": true
87
+ },
88
+ "s2_ckpt_dir": "logs/s2/big2k1",
89
+ "content_module": "cnhubert"
90
+ }
configs/train.yaml ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ gpu:
2
+ n_card: 1
3
+ n_process_per_card: 2
4
+ io:
5
+ text_path: D:\RVC1006\GPT-SoVITS\GPT_SoVITS
6
+ save_every_n_epoch: 1
7
+ precision: 16-mixed
8
+ gradient_clip: 1.0
9
+ optimizer:
10
+ lr: 0.01
11
+ lr_init: 0.00001
12
+ lr_end: 0.0001
13
+ warmup_steps: 2000
14
+ decay_steps: 40000
15
+ data:
16
+ max_eval_sample: 8
17
+ max_sec: 54
18
+ num_workers: 1
19
+ pad_val: 1024 # same as EOS in model
20
+ model:
21
+ vocab_size: 1025
22
+ phoneme_vocab_size: 512
23
+ embedding_dim: 512
24
+ hidden_dim: 512
25
+ head: 16
26
+ linear_units: 2048
27
+ n_layer: 24
28
+ dropout: 0
29
+ EOS: 1024
30
+ random_bert: 0
31
+ inference:
32
+ top_k: 5
download.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
# Script that instantiates the G2PW pinyin converter with the project's model
# directories (presumably to fetch/prepare the model files ahead of time —
# TODO confirm against G2PWPinyin).
import os, sys
# Put the current working directory first on sys.path.
now_dir = os.getcwd()
sys.path.insert(0, now_dir)
# NOTE(review): a relative import only resolves when this file is imported as
# part of a package; running it directly as a script would fail here — confirm
# how this file is meant to be invoked.
from .text.g2pw import G2PWPinyin
g2pw = G2PWPinyin(model_dir="GPT_SoVITS/text/G2PWModel",model_source="GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",v_to_u=False, neutral_tone_with_five=True)
feature_extractor/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ from . import cnhubert, whisper_enc
2
+
3
# Maps a content-module name to the feature-extractor module implementing it.
content_module_map = {
    'cnhubert': cnhubert,
    'whisper': whisper_enc
}
feature_extractor/cnhubert.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+
3
+ import librosa
4
+ import torch
5
+ import torch.nn.functional as F
6
+ import soundfile as sf
7
+ import os
8
+ from transformers import logging as tf_logging
9
+ tf_logging.set_verbosity_error()
10
+
11
+ import logging
12
+ logging.getLogger("numba").setLevel(logging.WARNING)
13
+
14
+ from transformers import (
15
+ Wav2Vec2FeatureExtractor,
16
+ HubertModel,
17
+ )
18
+
19
+ import utils
20
+ import torch.nn as nn
21
+
22
+ cnhubert_base_path = None
23
+
24
+
25
class CNHubert(nn.Module):
    """Wrap a locally stored HuBERT checkpoint together with its feature extractor.

    The checkpoint directory is read from the module-level
    ``cnhubert_base_path``, which must be assigned before instantiation.
    """

    def __init__(self):
        """Load the model and feature extractor from ``cnhubert_base_path``.

        Raises:
            FileNotFoundError: if ``cnhubert_base_path`` is unset (``None``)
                or does not exist on disk.
        """
        super().__init__()
        # The module-level default is None; fail with a clear message instead
        # of the TypeError that os.path.exists(None) would raise.
        if cnhubert_base_path is None:
            raise FileNotFoundError(
                "cnhubert_base_path is not set; assign the checkpoint directory "
                "to cnhubert_base_path before creating CNHubert"
            )
        if not os.path.exists(cnhubert_base_path):
            raise FileNotFoundError(cnhubert_base_path)
        self.model = HubertModel.from_pretrained(cnhubert_base_path, local_files_only=True)
        self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
            cnhubert_base_path, local_files_only=True
        )

    def forward(self, x):
        """Return the model's ``last_hidden_state`` for the waveform tensor ``x``.

        The feature extractor is invoked with a 16000 Hz sampling rate and its
        output is moved to the device of ``x`` before the model is called.
        """
        input_values = self.feature_extractor(
            x, return_tensors="pt", sampling_rate=16000
        ).input_values.to(x.device)
        feats = self.model(input_values)["last_hidden_state"]
        return feats
41
+
42
+
43
+ # class CNHubertLarge(nn.Module):
44
+ # def __init__(self):
45
+ # super().__init__()
46
+ # self.model = HubertModel.from_pretrained("/data/docker/liujing04/gpt-vits/chinese-hubert-large")
47
+ # self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("/data/docker/liujing04/gpt-vits/chinese-hubert-large")
48
+ # def forward(self, x):
49
+ # input_values = self.feature_extractor(x, return_tensors="pt", sampling_rate=16000).input_values.to(x.device)
50
+ # feats = self.model(input_values)["last_hidden_state"]
51
+ # return feats
52
+ #
53
+ # class CVec(nn.Module):
54
+ # def __init__(self):
55
+ # super().__init__()
56
+ # self.model = HubertModel.from_pretrained("/data/docker/liujing04/vc-webui-big/hubert_base")
57
+ # self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("/data/docker/liujing04/vc-webui-big/hubert_base")
58
+ # def forward(self, x):
59
+ # input_values = self.feature_extractor(x, return_tensors="pt", sampling_rate=16000).input_values.to(x.device)
60
+ # feats = self.model(input_values)["last_hidden_state"]
61
+ # return feats
62
+ #
63
+ # class cnw2v2base(nn.Module):
64
+ # def __init__(self):
65
+ # super().__init__()
66
+ # self.model = Wav2Vec2Model.from_pretrained("/data/docker/liujing04/gpt-vits/chinese-wav2vec2-base")
67
+ # self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("/data/docker/liujing04/gpt-vits/chinese-wav2vec2-base")
68
+ # def forward(self, x):
69
+ # input_values = self.feature_extractor(x, return_tensors="pt", sampling_rate=16000).input_values.to(x.device)
70
+ # feats = self.model(input_values)["last_hidden_state"]
71
+ # return feats
72
+
73
+
74
def get_model():
    """Create a CNHubert instance, switch it to evaluation mode and return it."""
    hubert = CNHubert()
    hubert.eval()
    return hubert
78
+
79
+
80
+ # def get_large_model():
81
+ # model = CNHubertLarge()
82
+ # model.eval()
83
+ # return model
84
+ #
85
+ # def get_model_cvec():
86
+ # model = CVec()
87
+ # model.eval()
88
+ # return model
89
+ #
90
+ # def get_model_cnw2v2base():
91
+ # model = cnw2v2base()
92
+ # model.eval()
93
+ # return model
94
+
95
+
96
def get_content(hmodel, wav_16k_tensor):
    """Run ``hmodel`` on the waveform without gradients and swap dims 1 and 2.

    Args:
        hmodel: callable applied to ``wav_16k_tensor``.
        wav_16k_tensor: input passed unchanged to ``hmodel``.

    Returns:
        The output of ``hmodel`` with dimensions 1 and 2 transposed.
    """
    with torch.no_grad():
        hidden = hmodel(wav_16k_tensor)
    transposed = hidden.transpose(1, 2)
    return transposed
100
+
101
+
102
if __name__ == "__main__":
    # Manual smoke test: extract features from one local file and print their shape.
    demo_wav = "/Users/Shared/原音频2.wav"
    hubert = get_model()
    wav_16k_tensor = utils.load_wav_to_torch_and_resample(demo_wav, 16000)
    print(get_content(hubert, wav_16k_tensor).shape)
feature_extractor/whisper_enc.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
def get_model():
    """Load the Whisper "small" checkpoint on CPU and return only its encoder."""
    import whisper

    return whisper.load_model("small", device="cpu").encoder
10
+
11
+
12
def get_content(model=None, wav_16k_tensor=None):
    """Encode a waveform with a Whisper encoder.

    The log-mel spectrogram is clipped to 3000 frames, padded/trimmed to
    exactly 3000 frames for the encoder, and the encoder output is cut back to
    half the pre-padding frame count before dimensions 1 and 2 are swapped.

    Raises:
        AssertionError: if the clipped spectrogram is not shorter than 3000 frames.
    """
    from whisper import log_mel_spectrogram, pad_or_trim

    max_frames = 3000
    target_device = next(model.parameters()).device
    mel = log_mel_spectrogram(wav_16k_tensor).to(target_device)[:, :max_frames]
    n_frames = mel.shape[-1]
    feature_len = n_frames // 2
    assert n_frames < max_frames, "输入音频过长，只允许输入30以内音频"
    with torch.no_grad():
        encoded = model(pad_or_trim(mel, max_frames).unsqueeze(0))
        feature = encoded[:1, :feature_len, :].transpose(1, 2)
    return feature
inference_cli.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import soundfile as sf
4
+
5
+ from tools.i18n.i18n import I18nAuto
6
+ from GPT_SoVITS.inference_webui import change_gpt_weights, change_sovits_weights, get_tts_wav
7
+
8
+ i18n = I18nAuto()
9
+
10
def synthesize(GPT_model_path, SoVITS_model_path, ref_audio_path, ref_text_path, ref_language, target_text_path, target_language, output_path):
    """Synthesize speech for a target text and save it as ``output.wav``.

    Args:
        GPT_model_path: path passed to ``change_gpt_weights``.
        SoVITS_model_path: path passed to ``change_sovits_weights``.
        ref_audio_path: path of the reference audio file.
        ref_text_path: UTF-8 text file holding the reference transcript.
        ref_language: language label of the reference audio (passed through ``i18n``).
        target_text_path: UTF-8 text file holding the text to synthesize.
        target_language: language label of the target text (passed through ``i18n``).
        output_path: directory that receives ``output.wav``; created if missing.
    """
    # Read reference text
    with open(ref_text_path, 'r', encoding='utf-8') as file:
        ref_text = file.read()

    # Read target text
    with open(target_text_path, 'r', encoding='utf-8') as file:
        target_text = file.read()

    # Change model weights
    change_gpt_weights(gpt_path=GPT_model_path)
    change_sovits_weights(sovits_path=SoVITS_model_path)

    # Synthesize audio
    synthesis_result = get_tts_wav(ref_wav_path=ref_audio_path,
                                   prompt_text=ref_text,
                                   prompt_language=i18n(ref_language),
                                   text=target_text,
                                   text_language=i18n(target_language), top_p=1, temperature=1)

    # Only the last yielded (sampling_rate, audio) pair is written out.
    result_list = list(synthesis_result)

    if result_list:
        last_sampling_rate, last_audio_data = result_list[-1]
        # Writing fails if the directory does not exist yet, so create it first.
        os.makedirs(output_path, exist_ok=True)
        output_wav_path = os.path.join(output_path, "output.wav")
        sf.write(output_wav_path, last_audio_data, last_sampling_rate)
        print(f"Audio saved to {output_wav_path}")
37
+
38
def main():
    """Parse the command-line options and run a single synthesis."""
    reference_languages = ["中文", "英文", "日文"]
    target_languages = ["中文", "英文", "日文", "中英混合", "日英混合", "多语种混合"]
    # Every option is required; only the help text and allowed choices vary.
    option_specs = [
        ('--gpt_model', {"help": "Path to the GPT model file"}),
        ('--sovits_model', {"help": "Path to the SoVITS model file"}),
        ('--ref_audio', {"help": "Path to the reference audio file"}),
        ('--ref_text', {"help": "Path to the reference text file"}),
        ('--ref_language', {"choices": reference_languages, "help": "Language of the reference audio"}),
        ('--target_text', {"help": "Path to the target text file"}),
        ('--target_language', {"choices": target_languages, "help": "Language of the target text"}),
        ('--output_path', {"help": "Path to the output directory"}),
    ]

    parser = argparse.ArgumentParser(description="GPT-SoVITS Command Line Tool")
    for flag, extra in option_specs:
        parser.add_argument(flag, required=True, **extra)

    args = parser.parse_args()

    synthesize(
        args.gpt_model,
        args.sovits_model,
        args.ref_audio,
        args.ref_text,
        args.ref_language,
        args.target_text,
        args.target_language,
        args.output_path,
    )
52
+
53
+ if __name__ == '__main__':
54
+ main()
55
+
inference_gui.py ADDED
@@ -0,0 +1,310 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ from PyQt5.QtCore import QEvent
4
+ from PyQt5.QtWidgets import QApplication, QMainWindow, QLabel, QLineEdit, QPushButton, QTextEdit
5
+ from PyQt5.QtWidgets import QGridLayout, QVBoxLayout, QWidget, QFileDialog, QStatusBar, QComboBox
6
+ import soundfile as sf
7
+
8
+ from tools.i18n.i18n import I18nAuto
9
+ i18n = I18nAuto()
10
+
11
+ from inference_webui import gpt_path, sovits_path, change_gpt_weights, change_sovits_weights, get_tts_wav
12
+
13
+
14
+ class GPTSoVITSGUI(QMainWindow):
15
+ GPT_Path = gpt_path
16
+ SoVITS_Path = sovits_path
17
+
18
+ def __init__(self):
19
+ super().__init__()
20
+
21
+ self.setWindowTitle('GPT-SoVITS GUI')
22
+ self.setGeometry(800, 450, 950, 850)
23
+
24
+ self.setStyleSheet("""
25
+ QWidget {
26
+ background-color: #a3d3b1;
27
+ }
28
+
29
+ QTabWidget::pane {
30
+ background-color: #a3d3b1;
31
+ }
32
+
33
+ QTabWidget::tab-bar {
34
+ alignment: left;
35
+ }
36
+
37
+ QTabBar::tab {
38
+ background: #8da4bf;
39
+ color: #ffffff;
40
+ padding: 8px;
41
+ }
42
+
43
+ QTabBar::tab:selected {
44
+ background: #2a3f54;
45
+ }
46
+
47
+ QLabel {
48
+ color: #000000;
49
+ }
50
+
51
+ QPushButton {
52
+ background-color: #4CAF50;
53
+ color: white;
54
+ padding: 8px;
55
+ border: 1px solid #4CAF50;
56
+ border-radius: 4px;
57
+ }
58
+
59
+ QPushButton:hover {
60
+ background-color: #45a049;
61
+ border: 1px solid #45a049;
62
+ box-shadow: 2px 2px 2px rgba(0, 0, 0, 0.1);
63
+ }
64
+ """)
65
+
66
+ license_text = (
67
+ "本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责. "
68
+ "如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE.")
69
+ license_label = QLabel(license_text)
70
+ license_label.setWordWrap(True)
71
+
72
+ self.GPT_model_label = QLabel("选择GPT模型:")
73
+ self.GPT_model_input = QLineEdit()
74
+ self.GPT_model_input.setPlaceholderText("拖拽或选择文件")
75
+ self.GPT_model_input.setText(self.GPT_Path)
76
+ self.GPT_model_input.setReadOnly(True)
77
+ self.GPT_model_button = QPushButton("选择GPT模型文件")
78
+ self.GPT_model_button.clicked.connect(self.select_GPT_model)
79
+
80
+ self.SoVITS_model_label = QLabel("选择SoVITS模型:")
81
+ self.SoVITS_model_input = QLineEdit()
82
+ self.SoVITS_model_input.setPlaceholderText("拖拽或选择文件")
83
+ self.SoVITS_model_input.setText(self.SoVITS_Path)
84
+ self.SoVITS_model_input.setReadOnly(True)
85
+ self.SoVITS_model_button = QPushButton("选择SoVITS模型文件")
86
+ self.SoVITS_model_button.clicked.connect(self.select_SoVITS_model)
87
+
88
+ self.ref_audio_label = QLabel("上传参考音频:")
89
+ self.ref_audio_input = QLineEdit()
90
+ self.ref_audio_input.setPlaceholderText("拖拽或选择文件")
91
+ self.ref_audio_input.setReadOnly(True)
92
+ self.ref_audio_button = QPushButton("选择音频文件")
93
+ self.ref_audio_button.clicked.connect(self.select_ref_audio)
94
+
95
+ self.ref_text_label = QLabel("参考音频文本:")
96
+ self.ref_text_input = QLineEdit()
97
+ self.ref_text_input.setPlaceholderText("直接输入文字或上传文本")
98
+ self.ref_text_button = QPushButton("上传文本")
99
+ self.ref_text_button.clicked.connect(self.upload_ref_text)
100
+
101
+ self.ref_language_label = QLabel("参考音频语言:")
102
+ self.ref_language_combobox = QComboBox()
103
+ self.ref_language_combobox.addItems(["中文", "英文", "日文", "中英混合", "日英混合", "多语种混合"])
104
+ self.ref_language_combobox.setCurrentText("多语种混合")
105
+
106
+ self.target_text_label = QLabel("合成目标文本:")
107
+ self.target_text_input = QLineEdit()
108
+ self.target_text_input.setPlaceholderText("直接输入文字或上传文本")
109
+ self.target_text_button = QPushButton("上传文本")
110
+ self.target_text_button.clicked.connect(self.upload_target_text)
111
+
112
+ self.target_language_label = QLabel("合成音频语言:")
113
+ self.target_language_combobox = QComboBox()
114
+ self.target_language_combobox.addItems(["中文", "英文", "日文", "中英混合", "日英混合", "多语种混合"])
115
+ self.target_language_combobox.setCurrentText("多语种混合")
116
+
117
+ self.output_label = QLabel("输出音频路径:")
118
+ self.output_input = QLineEdit()
119
+ self.output_input.setPlaceholderText("拖拽或选择文件")
120
+ self.output_input.setReadOnly(True)
121
+ self.output_button = QPushButton("选择文件夹")
122
+ self.output_button.clicked.connect(self.select_output_path)
123
+
124
+ self.output_text = QTextEdit()
125
+ self.output_text.setReadOnly(True)
126
+
127
+ self.add_drag_drop_events([
128
+ self.GPT_model_input,
129
+ self.SoVITS_model_input,
130
+ self.ref_audio_input,
131
+ self.ref_text_input,
132
+ self.target_text_input,
133
+ self.output_input,
134
+ ])
135
+
136
+ self.synthesize_button = QPushButton("合成")
137
+ self.synthesize_button.clicked.connect(self.synthesize)
138
+
139
+ self.clear_output_button = QPushButton("清空输出")
140
+ self.clear_output_button.clicked.connect(self.clear_output)
141
+
142
+ self.status_bar = QStatusBar()
143
+
144
+ main_layout = QVBoxLayout()
145
+
146
+ input_layout = QGridLayout(self)
147
+ input_layout.setSpacing(10)
148
+
149
+ input_layout.addWidget(license_label, 0, 0, 1, 3)
150
+
151
+ input_layout.addWidget(self.GPT_model_label, 1, 0)
152
+ input_layout.addWidget(self.GPT_model_input, 2, 0, 1, 2)
153
+ input_layout.addWidget(self.GPT_model_button, 2, 2)
154
+
155
+ input_layout.addWidget(self.SoVITS_model_label, 3, 0)
156
+ input_layout.addWidget(self.SoVITS_model_input, 4, 0, 1, 2)
157
+ input_layout.addWidget(self.SoVITS_model_button, 4, 2)
158
+
159
+ input_layout.addWidget(self.ref_audio_label, 5, 0)
160
+ input_layout.addWidget(self.ref_audio_input, 6, 0, 1, 2)
161
+ input_layout.addWidget(self.ref_audio_button, 6, 2)
162
+
163
+ input_layout.addWidget(self.ref_language_label, 7, 0)
164
+ input_layout.addWidget(self.ref_language_combobox, 8, 0, 1, 1)
165
+ input_layout.addWidget(self.ref_text_label, 9, 0)
166
+ input_layout.addWidget(self.ref_text_input, 10, 0, 1, 2)
167
+ input_layout.addWidget(self.ref_text_button, 10, 2)
168
+
169
+ input_layout.addWidget(self.target_language_label, 11, 0)
170
+ input_layout.addWidget(self.target_language_combobox, 12, 0, 1, 1)
171
+ input_layout.addWidget(self.target_text_label, 13, 0)
172
+ input_layout.addWidget(self.target_text_input, 14, 0, 1, 2)
173
+ input_layout.addWidget(self.target_text_button, 14, 2)
174
+
175
+ input_layout.addWidget(self.output_label, 15, 0)
176
+ input_layout.addWidget(self.output_input, 16, 0, 1, 2)
177
+ input_layout.addWidget(self.output_button, 16, 2)
178
+
179
+ main_layout.addLayout(input_layout)
180
+
181
+ output_layout = QVBoxLayout()
182
+ output_layout.addWidget(self.output_text)
183
+ main_layout.addLayout(output_layout)
184
+
185
+ main_layout.addWidget(self.synthesize_button)
186
+
187
+ main_layout.addWidget(self.clear_output_button)
188
+
189
+ main_layout.addWidget(self.status_bar)
190
+
191
+ self.central_widget = QWidget()
192
+ self.central_widget.setLayout(main_layout)
193
+ self.setCentralWidget(self.central_widget)
194
+
195
+ def dragEnterEvent(self, event):
196
+ if event.mimeData().hasUrls():
197
+ event.acceptProposedAction()
198
+
199
+ def dropEvent(self, event):
200
+ if event.mimeData().hasUrls():
201
+ file_paths = [url.toLocalFile() for url in event.mimeData().urls()]
202
+ if len(file_paths) == 1:
203
+ self.update_ref_audio(file_paths[0])
204
+ else:
205
+ self.update_ref_audio(", ".join(file_paths))
206
+
207
+ def add_drag_drop_events(self, widgets):
208
+ for widget in widgets:
209
+ widget.setAcceptDrops(True)
210
+ widget.installEventFilter(self)
211
+
212
+ def eventFilter(self, obj, event):
213
+ if event.type() in (QEvent.DragEnter, QEvent.Drop):
214
+ mime_data = event.mimeData()
215
+ if mime_data.hasUrls():
216
+ event.acceptProposedAction()
217
+
218
+ return super().eventFilter(obj, event)
219
+
220
+ def select_GPT_model(self):
221
+ file_path, _ = QFileDialog.getOpenFileName(self, "选择GPT模型文件", "", "GPT Files (*.ckpt)")
222
+ if file_path:
223
+ self.GPT_model_input.setText(file_path)
224
+
225
+ def select_SoVITS_model(self):
226
+ file_path, _ = QFileDialog.getOpenFileName(self, "选择SoVITS模型文件", "", "SoVITS Files (*.pth)")
227
+ if file_path:
228
+ self.SoVITS_model_input.setText(file_path)
229
+
230
+ def select_ref_audio(self):
231
+ file_path, _ = QFileDialog.getOpenFileName(self, "选择参考音频文件", "", "Audio Files (*.wav *.mp3)")
232
+ if file_path:
233
+ self.update_ref_audio(file_path)
234
+
235
+ def upload_ref_text(self):
236
+ file_path, _ = QFileDialog.getOpenFileName(self, "选择文本文件", "", "Text Files (*.txt)")
237
+ if file_path:
238
+ with open(file_path, 'r', encoding='utf-8') as file:
239
+ content = file.read()
240
+ self.ref_text_input.setText(content)
241
+
242
+ def upload_target_text(self):
243
+ file_path, _ = QFileDialog.getOpenFileName(self, "选择文本文件", "", "Text Files (*.txt)")
244
+ if file_path:
245
+ with open(file_path, 'r', encoding='utf-8') as file:
246
+ content = file.read()
247
+ self.target_text_input.setText(content)
248
+
249
+ def select_output_path(self):
250
+ options = QFileDialog.Options()
251
+ options |= QFileDialog.DontUseNativeDialog
252
+ options |= QFileDialog.ShowDirsOnly
253
+
254
+ folder_dialog = QFileDialog()
255
+ folder_dialog.setOptions(options)
256
+ folder_dialog.setFileMode(QFileDialog.Directory)
257
+
258
+ if folder_dialog.exec_():
259
+ folder_path = folder_dialog.selectedFiles()[0]
260
+ self.output_input.setText(folder_path)
261
+
262
+ def update_ref_audio(self, file_path):
263
+ self.ref_audio_input.setText(file_path)
264
+
265
+ def clear_output(self):
266
+ self.output_text.clear()
267
+
268
+ def synthesize(self):
269
+ GPT_model_path = self.GPT_model_input.text()
270
+ SoVITS_model_path = self.SoVITS_model_input.text()
271
+ ref_audio_path = self.ref_audio_input.text()
272
+ language_combobox = self.ref_language_combobox.currentText()
273
+ language_combobox = i18n(language_combobox)
274
+ ref_text = self.ref_text_input.text()
275
+ target_language_combobox = self.target_language_combobox.currentText()
276
+ target_language_combobox = i18n(target_language_combobox)
277
+ target_text = self.target_text_input.text()
278
+ output_path = self.output_input.text()
279
+
280
+ if GPT_model_path != self.GPT_Path:
281
+ change_gpt_weights(gpt_path=GPT_model_path)
282
+ self.GPT_Path = GPT_model_path
283
+ if SoVITS_model_path != self.SoVITS_Path:
284
+ change_sovits_weights(sovits_path=SoVITS_model_path)
285
+ self.SoVITS_Path = SoVITS_model_path
286
+
287
+ synthesis_result = get_tts_wav(ref_wav_path=ref_audio_path,
288
+ prompt_text=ref_text,
289
+ prompt_language=language_combobox,
290
+ text=target_text,
291
+ text_language=target_language_combobox)
292
+
293
+ result_list = list(synthesis_result)
294
+
295
+ if result_list:
296
+ last_sampling_rate, last_audio_data = result_list[-1]
297
+ output_wav_path = os.path.join(output_path, "output.wav")
298
+ sf.write(output_wav_path, last_audio_data, last_sampling_rate)
299
+
300
+ result = "Audio saved to " + output_wav_path
301
+
302
+ self.status_bar.showMessage("合成完成!输出路径:" + output_wav_path, 5000)
303
+ self.output_text.append("处理结果:\n" + result)
304
+
305
+
306
+ if __name__ == '__main__':
307
+ app = QApplication(sys.argv)
308
+ mainWin = GPTSoVITSGUI()
309
+ mainWin.show()
310
+ sys.exit(app.exec_())
inference_webui.py ADDED
@@ -0,0 +1,678 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ 按中英混合识别
3
+ 按日英混合识别
4
+ 多语种启动切分识别语种
5
+ 全部按中文识别
6
+ 全部按英文识别
7
+ 全部按日文识别
8
+ '''
9
+ import logging
10
+ import traceback
11
+
12
+ logging.getLogger("markdown_it").setLevel(logging.ERROR)
13
+ logging.getLogger("urllib3").setLevel(logging.ERROR)
14
+ logging.getLogger("httpcore").setLevel(logging.ERROR)
15
+ logging.getLogger("httpx").setLevel(logging.ERROR)
16
+ logging.getLogger("asyncio").setLevel(logging.ERROR)
17
+ logging.getLogger("charset_normalizer").setLevel(logging.ERROR)
18
+ logging.getLogger("torchaudio._extension").setLevel(logging.ERROR)
19
+ logging.getLogger("multipart.multipart").setLevel(logging.ERROR)
20
+ import gradio.analytics as analytics
21
+ analytics.version_check = lambda:None
22
+ analytics.get_local_ip_address= lambda :"127.0.0.1"##不干掉本地联不通亚马逊的get_local_ip服务器
23
+ import nltk
24
+ nltk.download('averaged_perceptron_tagger_eng')
25
+ import LangSegment, os, re, sys, json
26
+ import pdb
27
+ import spaces
28
+ import torch
29
+
30
+ version="v2"#os.environ.get("version","v2")
31
+ cnhubert_base_path = os.environ.get(
32
+ "cnhubert_base_path", "pretrained_models/chinese-hubert-base"
33
+ )
34
+ bert_path = os.environ.get(
35
+ "bert_path", "pretrained_models/chinese-roberta-wwm-ext-large"
36
+ )
37
+
38
+ punctuation = set(['!', '?', '…', ',', '.', '-'," "])
39
+ import gradio as gr
40
+ from transformers import AutoModelForMaskedLM, AutoTokenizer
41
+ import numpy as np
42
+ import librosa
43
+ from feature_extractor import cnhubert
44
+
45
+ cnhubert.cnhubert_base_path = cnhubert_base_path
46
+
47
+ from module.models import SynthesizerTrn
48
+ from AR.models.t2s_lightning_module import Text2SemanticLightningModule
49
+ from text import cleaned_text_to_sequence
50
+ from text.cleaner import clean_text
51
+ from time import time as ttime
52
+ from module.mel_processing import spectrogram_torch
53
+ from tools.my_utils import load_audio
54
+ from tools.i18n.i18n import I18nAuto, scan_language_list
55
+
56
+ # language=os.environ.get("language","Auto")
57
+ # language=sys.argv[-1] if sys.argv[-1] in scan_language_list() else language
58
+ i18n = I18nAuto(language="Auto")
59
+
60
+ # os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1' # 确保直接启动推理UI时也能够设置。
61
+
62
# Use CUDA with half precision when it is available; otherwise fall back to
# CPU with full precision.
if torch.cuda.is_available():
    device = "cuda"
    is_half = True  # eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
else:
    device = "cpu"
    is_half=False
68
+
69
# Maps the (localized) language label shown to the user to an internal language code.
dict_language_v1 = {
    i18n("中文"): "all_zh",  # treat everything as Chinese
    i18n("英文"): "en",  # treat everything as English (unchanged)
    i18n("日文"): "all_ja",  # treat everything as Japanese
    i18n("中英混合"): "zh",  # recognize as mixed Chinese-English (unchanged)
    i18n("日英混合"): "ja",  # recognize as mixed Japanese-English (unchanged)
    i18n("多语种混合"): "auto",  # multilingual: segment the text and detect each language
}
# Same mapping for v2, which adds Cantonese and Korean entries.
dict_language_v2 = {
    i18n("中文"): "all_zh",  # treat everything as Chinese
    i18n("英文"): "en",  # treat everything as English (unchanged)
    i18n("日文"): "all_ja",  # treat everything as Japanese
    i18n("粤语"): "all_yue",  # treat everything as Cantonese
    i18n("韩文"): "all_ko",  # treat everything as Korean
    i18n("中英混合"): "zh",  # recognize as mixed Chinese-English (unchanged)
    i18n("日英混合"): "ja",  # recognize as mixed Japanese-English (unchanged)
    i18n("粤英混合"): "yue",  # recognize as mixed Cantonese-English (unchanged)
    i18n("韩英混合"): "ko",  # recognize as mixed Korean-English (unchanged)
    i18n("多语种混合"): "auto",  # multilingual: segment the text and detect each language
    i18n("多语种混合(粤语)"): "auto_yue",  # multilingual (Cantonese): segment the text and detect each language
}
# Pick the table matching the configured version.
dict_language = dict_language_v1 if version =='v1' else dict_language_v2
91
+
92
# Load the BERT tokenizer and masked-LM model from bert_path; the model is
# converted to half precision first when is_half is set, then moved to device.
tokenizer = AutoTokenizer.from_pretrained(bert_path)
bert_model = AutoModelForMaskedLM.from_pretrained(bert_path)
if is_half:
    bert_model = bert_model.half()
bert_model = bert_model.to(device)
98
+
99
+
100
def get_bert_feature(text, word2ph):
    """Return BERT features for ``text`` expanded to one column per phone.

    Args:
        text: input string; its length must equal ``len(word2ph)``.
        word2ph: per-character repeat counts; row ``i`` of the token features
            is repeated ``word2ph[i]`` times.

    Returns:
        The concatenated repeated rows, transposed.

    Raises:
        AssertionError: if ``word2ph`` and ``text`` differ in length.
    """
    with torch.no_grad():
        encoded = tokenizer(text, return_tensors="pt")
        for name in encoded:
            encoded[name] = encoded[name].to(device)
        outputs = bert_model(**encoded, output_hidden_states=True)
        # Third-from-last hidden state of the first batch item, on CPU,
        # without its first and last rows.
        token_feats = torch.cat(outputs["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
    assert len(word2ph) == len(text)
    repeated_rows = [
        token_feats[idx].repeat(count, 1) for idx, count in enumerate(word2ph)
    ]
    stacked = torch.cat(repeated_rows, dim=0)
    return stacked.T
114
+
115
+
116
class DictToAttrRecursive(dict):
    """Dictionary whose entries can also be read, set and deleted as attributes.

    Nested dictionaries are wrapped recursively, so ``cfg.model.rate`` works
    for ``{"model": {"rate": ...}}``. Values are stored both as dict items and
    as instance attributes.
    """

    def __init__(self, input_dict):
        """Copy ``input_dict``, wrapping nested dictionaries recursively."""
        super().__init__(input_dict)
        for key, value in input_dict.items():
            # __setattr__ wraps nested dicts and stores the value as both
            # a dict item and an instance attribute.
            setattr(self, key, value)

    def __getattr__(self, item):
        """Fall back to the dict entry when no attribute named ``item`` exists.

        Raises:
            AttributeError: if ``item`` is not a key of the dictionary.
        """
        try:
            return self[item]
        except KeyError:
            raise AttributeError(f"Attribute {item} not found") from None

    def __setattr__(self, key, value):
        """Store ``value`` as both item and attribute, wrapping plain dicts."""
        if isinstance(value, dict):
            value = DictToAttrRecursive(value)
        super().__setitem__(key, value)
        super().__setattr__(key, value)

    def __delattr__(self, item):
        """Remove ``item`` from both the dict entries and the instance attributes.

        Raises:
            AttributeError: if ``item`` is not a key of the dictionary.
        """
        try:
            del self[item]
        except KeyError:
            raise AttributeError(f"Attribute {item} not found") from None
        # The value is mirrored in the instance __dict__ by __setattr__; drop
        # that copy too, otherwise attribute access would still succeed.
        self.__dict__.pop(item, None)
142
+
143
+
144
# Load the CN-HuBERT model and move it to the inference device
# (half precision when is_half is set).
ssl_model = cnhubert.get_model()
ssl_model = (ssl_model.half() if is_half == True else ssl_model).to(device)
149
+
150
+
151
def change_sovits_weights(sovits_path,prompt_language=None,text_language=None):
    """Load a SoVITS checkpoint and rebuild the module-level synthesizer state.

    Updates the globals ``vq_model``, ``hps``, ``version`` and ``dict_language``.

    Args:
        sovits_path: Path of the checkpoint; it must contain ``"config"`` and
            ``"weight"`` entries.
        prompt_language: Currently selected prompt-language label, or None.
        text_language: Currently selected target-language label, or None.

    Returns:
        When both language labels are given, a 6-tuple of Gradio update dicts
        (two dropdown choice lists, then prompt text, prompt language, target
        text, target language). Otherwise None.
    """
    global vq_model, hps, version, dict_language
    # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
    checkpoint = torch.load(sovits_path, map_location="cpu")
    weights = checkpoint["weight"]
    hps = DictToAttrRecursive(checkpoint["config"])
    hps.model.semantic_frame_rate = "25hz"
    # The model version is inferred from the size of the text-embedding table.
    is_v1 = weights['enc_p.text_embedding.weight'].shape[0] == 322
    hps.model.version = "v1" if is_v1 else "v2"
    version = hps.model.version
    vq_model = SynthesizerTrn(
        hps.data.filter_length // 2 + 1,
        hps.train.segment_size // hps.data.hop_length,
        n_speakers=hps.data.n_speakers,
        **hps.model
    )
    if "pretrained" not in sovits_path:
        del vq_model.enc_q
    if is_half == True:
        vq_model = vq_model.half()
    vq_model = vq_model.to(device)
    vq_model.eval()
    print(vq_model.load_state_dict(weights, strict=False))
    dict_language = dict_language_v1 if version =='v1' else dict_language_v2
    if prompt_language is None or text_language is None:
        return None

    labels = list(dict_language.keys())

    def _keep_or_reset(label):
        # Keep the current selection if the new table still offers it;
        # otherwise clear the text box and fall back to Chinese.
        if label in labels:
            return {'__type__':'update'}, {'__type__':'update', 'value':label}
        return {'__type__':'update', 'value':''}, {'__type__':'update', 'value':i18n("中文")}

    prompt_text_update, prompt_language_update = _keep_or_reset(prompt_language)
    text_update, text_language_update = _keep_or_reset(text_language)
    return (
        {'__type__':'update', 'choices':list(dict_language.keys())},
        {'__type__':'update', 'choices':list(dict_language.keys())},
        prompt_text_update,
        prompt_language_update,
        text_update,
        text_language_update,
    )


# Load the default pretrained SoVITS weights at import time.
change_sovits_weights("pretrained_models/gsv-v2final-pretrained/s2G2333k.pth")
194
+
195
+
196
def change_gpt_weights(gpt_path):
    """Load a GPT (text-to-semantic) checkpoint into the module-level globals.

    Updates ``hz``, ``max_sec``, ``t2s_model`` and ``config``, and prints the
    parameter count of the loaded model.

    Args:
        gpt_path: Path of the checkpoint; it must contain ``"config"`` and
            ``"weight"`` entries.
    """
    global hz, max_sec, t2s_model, config
    hz = 50
    # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
    checkpoint = torch.load(gpt_path, map_location="cpu")
    config = checkpoint["config"]
    max_sec = config["data"]["max_sec"]
    t2s_model = Text2SemanticLightningModule(config, "****", is_train=False)
    t2s_model.load_state_dict(checkpoint["weight"])
    if is_half == True:
        t2s_model = t2s_model.half()
    t2s_model = t2s_model.to(device)
    t2s_model.eval()
    n_params = sum(param.nelement() for param in t2s_model.parameters())
    print("Number of parameter: %.2fM" % (n_params / 1e6))


# Load the default pretrained GPT weights at import time.
change_gpt_weights("pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt")
213
+
214
+
215
def get_spepc(hps, filename):
    """Load an audio file and return its spectrogram.

    Args:
        hps: Hyper-parameter object; ``hps.data`` supplies ``sampling_rate``,
            ``filter_length``, ``hop_length`` and ``win_length``.
        filename: Path of the audio file to load.

    Returns:
        The result of ``spectrogram_torch`` for the audio with a leading
        dimension added.
    """
    cfg = hps.data
    audio = torch.FloatTensor(load_audio(filename, int(cfg.sampling_rate)))
    peak = audio.abs().max()
    # Scale down over-range audio, dividing by at most 2.
    if peak > 1:
        audio /= min(2, peak)
    return spectrogram_torch(
        audio.unsqueeze(0),
        cfg.filter_length,
        cfg.sampling_rate,
        cfg.hop_length,
        cfg.win_length,
        center=False,
    )
231
+
232
def clean_text_inf(text, language, version):
    """Clean ``text`` and convert its phones to an id sequence.

    Returns:
        ``(phone_ids, word2ph, norm_text)`` where ``word2ph`` and
        ``norm_text`` come straight from ``clean_text``.
    """
    raw_phones, word2ph, norm_text = clean_text(text, language, version)
    return cleaned_text_to_sequence(raw_phones, version), word2ph, norm_text
236
+
237
# Tensor dtype used for inference features.
dtype = torch.float16 if is_half == True else torch.float32


def get_bert_inf(phones, word2ph, norm_text, language):
    """Return BERT features for one text segment.

    Chinese segments (``"zh"`` / ``"all_zh"``) get real BERT features; every
    other language gets a zero tensor of shape ``(1024, len(phones))``.
    """
    if language.replace("all_", "") == "zh":
        return get_bert_feature(norm_text, word2ph).to(device)
    zero_dtype = torch.float16 if is_half == True else torch.float32
    return torch.zeros((1024, len(phones)), dtype=zero_dtype).to(device)
249
+
250
+
251
# Characters treated as sentence/clause separators. Both the ASCII and the
# full-width forms are listed: the original literal repeated the ASCII
# comma, question mark, exclamation mark and colon, which is a no-op in a set
# and left full-width punctuation unrecognised.
splits = {",", "。", "?", "!", ",", ".", "?", "!", "~", ":", ":", "—", "…"}


def get_first(text):
    """Return the text before the first separator in ``splits``, stripped.

    If ``text`` contains no separator, the whole stripped text is returned.
    """
    pattern = "[" + "".join(re.escape(sep) for sep in splits) + "]"
    # Only the first piece is needed, so stop after one split.
    return re.split(pattern, text, maxsplit=1)[0].strip()
258
+
259
+ from text import chinese
260
def get_phones_and_bert(text,language,version):
    """Convert ``text`` to phone ids, BERT features and normalised text.

    Single-language codes (``"en"``, ``"all_zh"``, ``"all_ja"``, ``"all_ko"``,
    ``"all_yue"``) treat the whole input as one language. Mixed codes
    (``"zh"``, ``"ja"``, ``"ko"``, ``"yue"``, ``"auto"``, ``"auto_yue"``)
    segment the input with ``LangSegment`` and process each segment with its
    own language.

    Returns:
        ``(phones, bert, norm_text)`` with ``bert`` cast to the module-level
        ``dtype``.

    Raises:
        ValueError: If ``language`` is not one of the supported codes.
    """
    if language in {"en", "all_zh", "all_ja", "all_ko", "all_yue"}:
        language = language.replace("all_","")
        if language == "en":
            LangSegment.setfilters(["en"])
            formattext = " ".join(tmp["text"] for tmp in LangSegment.getTexts(text))
        else:
            # CJK ideographs cannot be told apart by language, so trust the user's choice.
            formattext = text
        # Collapse runs of spaces to one. The previous loop replaced " " with
        # " " and never terminated once the text contained a space.
        formattext = re.sub(r" {2,}", " ", formattext)
        if language in ("zh", "yue") and re.search(r'[A-Za-z]', formattext):
            # Text with Latin letters: upper-case them, normalise, and
            # re-process in mixed mode for the same language.
            formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext)
            formattext = chinese.mix_text_normalize(formattext)
            return get_phones_and_bert(formattext,language,version)
        phones, word2ph, norm_text = clean_text_inf(formattext, language, version)
        if language == "zh":
            bert = get_bert_feature(norm_text, word2ph).to(device)
        else:
            bert = torch.zeros(
                (1024, len(phones)),
                dtype=torch.float16 if is_half == True else torch.float32,
            ).to(device)
    elif language in {"zh", "ja", "ko", "yue", "auto", "auto_yue"}:
        textlist=[]
        langlist=[]
        LangSegment.setfilters(["zh","ja","en","ko"])
        for tmp in LangSegment.getTexts(text):
            if language == "auto":
                seg_lang = tmp["lang"]
            elif language == "auto_yue":
                seg_lang = "yue" if tmp["lang"] == "zh" else tmp["lang"]
            elif tmp["lang"] == "en":
                seg_lang = tmp["lang"]
            else:
                # CJK ideographs cannot be told apart by language, so trust the user's choice.
                seg_lang = language
            langlist.append(seg_lang)
            textlist.append(tmp["text"])
        print(textlist)
        print(langlist)
        phones_list = []
        bert_list = []
        norm_text_list = []
        for seg_text, seg_lang in zip(textlist, langlist):
            seg_phones, word2ph, seg_norm = clean_text_inf(seg_text, seg_lang, version)
            bert_list.append(get_bert_inf(seg_phones, word2ph, seg_norm, seg_lang))
            phones_list.append(seg_phones)
            norm_text_list.append(seg_norm)
        bert = torch.cat(bert_list, dim=1)
        phones = [phone for seg_phones in phones_list for phone in seg_phones]
        norm_text = ''.join(norm_text_list)
    else:
        raise ValueError(f"Unsupported language code: {language!r}")

    return phones,bert.to(dtype),norm_text
328
+
329
+
330
def merge_short_text_in_array(texts, threshold):
    """Concatenate consecutive strings until each chunk has >= ``threshold`` chars.

    A trailing remainder shorter than ``threshold`` is appended to the last
    chunk, or returned alone when no chunk was produced. Inputs with fewer
    than two elements are returned unchanged (the same object).
    """
    if len(texts) < 2:
        return texts
    merged = []
    pending = ""
    for piece in texts:
        pending += piece
        if len(pending) >= threshold:
            merged.append(pending)
            pending = ""
    if pending:
        if merged:
            merged[-1] += pending
        else:
            merged.append(pending)
    return merged
346
+
347
## Inputs: ref_wav_path + prompt_text + prompt_language + text (single) + text_language + top_k + top_p + temperature
# Predicted semantic tokens of the previous run, keyed by sentence index.
# No eviction is implemented.
cache = {}


@torch.inference_mode()
@spaces.GPU
def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, how_to_cut=i18n("不切"), top_k=20, top_p=0.6, temperature=0.6, ref_free = False,speed=1,if_freeze=False,inp_refs=123):
    """Synthesize speech for ``text`` and yield ``(sampling_rate, int16_audio)``.

    Args:
        ref_wav_path: Path of the reference audio (3-10 s when a prompt text is used).
        prompt_text: Transcript of the reference audio; empty/None switches to
            reference-free mode.
        prompt_language: UI label of the prompt language (key of ``dict_language``).
        text: Target text to synthesize.
        text_language: UI label of the target language (key of ``dict_language``).
        how_to_cut: UI label of the sentence-splitting strategy.
        top_k, top_p, temperature: Sampling parameters passed to the GPT model.
        ref_free: Ignore ``prompt_text`` when True.
        speed: Speed factor passed to the decoder.
        if_freeze: Reuse the cached semantic tokens of the previous run.
        inp_refs: Optional list of uploaded files used as extra timbre
            references. Anything that is not a list/tuple (including the
            legacy default ``123``) means "no extra references".

    Raises:
        OSError: If the target text is empty or the reference audio is
            outside the 3-10 s range.
        ValueError: If no sentence remains after splitting.
    """
    global cache
    if not ref_wav_path:
        gr.Warning(i18n('请上传参考音频'))
    # Validate the target text up front; previously an empty text only warned
    # and then crashed with an IndexError on text[0].
    text = (text or "").strip("\n")
    if not text:
        gr.Warning(i18n('请填入推理文本'))
        raise OSError(i18n('请填入推理文本'))
    # A prompt made only of newlines is treated as missing instead of
    # crashing on prompt_text[-1].
    prompt_text = (prompt_text or "").strip("\n")
    if not prompt_text:
        ref_free = True
    t = []
    t0 = ttime()
    prompt_language = dict_language[prompt_language]
    text_language = dict_language[text_language]

    if not ref_free:
        if prompt_text[-1] not in splits:
            prompt_text += "。" if prompt_language != "en" else "."
        print(i18n("实际输入的参考文本:"), prompt_text)
    # Prepend a separator when the text starts with a very short first clause.
    if text[0] not in splits and len(get_first(text)) < 4:
        text = "。" + text if text_language != "en" else "." + text

    print(i18n("实际输入的目标文本:"), text)
    # 0.3 s of silence, appended after every synthesized sentence.
    zero_wav = np.zeros(
        int(hps.data.sampling_rate * 0.3),
        dtype=np.float16 if is_half == True else np.float32,
    )
    if not ref_free:
        with torch.no_grad():
            wav16k, sr = librosa.load(ref_wav_path, sr=16000)
            if wav16k.shape[0] > 160000 or wav16k.shape[0] < 48000:
                gr.Warning(i18n("参考音频在3~10秒范围外,请更换!"))
                raise OSError(i18n("参考音频在3~10秒范围外,请更换!"))
            wav16k = torch.from_numpy(wav16k)
            zero_wav_torch = torch.from_numpy(zero_wav)
            if is_half == True:
                wav16k = wav16k.half().to(device)
                zero_wav_torch = zero_wav_torch.half().to(device)
            else:
                wav16k = wav16k.to(device)
                zero_wav_torch = zero_wav_torch.to(device)
            wav16k = torch.cat([wav16k, zero_wav_torch])
            ssl_content = ssl_model.model(wav16k.unsqueeze(0))[
                "last_hidden_state"
            ].transpose(1, 2)
            codes = vq_model.extract_latent(ssl_content)
            prompt_semantic = codes[0, 0]
            prompt = prompt_semantic.unsqueeze(0).to(device)

    t1 = ttime()
    t.append(t1-t0)

    # Apply the first splitting strategy whose label matches.
    cutters = (
        (i18n("凑四句一切"), cut1),
        (i18n("凑50字一切"), cut2),
        (i18n("按中文句号。切"), cut3),
        (i18n("按英文句号.切"), cut4),
        (i18n("按标点符号切"), cut5),
    )
    for label, cutter in cutters:
        if how_to_cut == label:
            text = cutter(text)
            break
    while "\n\n" in text:
        text = text.replace("\n\n", "\n")
    print(i18n("实际输入的目标文本(切句后):"), text)
    texts = text.split("\n")
    texts = process_text(texts)
    texts = merge_short_text_in_array(texts, 5)
    audio_opt = []
    if not ref_free:
        phones1,bert1,norm_text1=get_phones_and_bert(prompt_text, prompt_language, version)

    # The reference spectrograms do not depend on the sentence, so load them
    # once. Unreadable extra references are reported and skipped.
    refers = []
    if isinstance(inp_refs, (list, tuple)):
        for ref_file in inp_refs:
            try:
                refers.append(get_spepc(hps, ref_file.name).to(dtype).to(device))
            except Exception:
                traceback.print_exc()
    if not refers:
        refers = [get_spepc(hps, ref_wav_path).to(dtype).to(device)]

    for i_text,text in enumerate(texts):
        # Skip blank lines in the target text.
        if len(text.strip()) == 0:
            continue
        if text[-1] not in splits:
            text += "。" if text_language != "en" else "."
        print(i18n("实际输入的目标文本(每句):"), text)
        phones2,bert2,norm_text2=get_phones_and_bert(text, text_language, version)
        print(i18n("前端处理后的文本(每句):"), norm_text2)
        if not ref_free:
            bert = torch.cat([bert1, bert2], 1)
            all_phoneme_ids = torch.LongTensor(phones1+phones2).to(device).unsqueeze(0)
        else:
            bert = bert2
            all_phoneme_ids = torch.LongTensor(phones2).to(device).unsqueeze(0)

        bert = bert.to(device).unsqueeze(0)
        all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)

        t2 = ttime()
        if i_text in cache and if_freeze == True:
            # Frozen mode: reuse the tokens of the previous run for this sentence index.
            pred_semantic = cache[i_text]
        else:
            with torch.no_grad():
                pred_semantic, idx = t2s_model.model.infer_panel(
                    all_phoneme_ids,
                    all_phoneme_len,
                    None if ref_free else prompt,
                    bert,
                    top_k=top_k,
                    top_p=top_p,
                    temperature=temperature,
                    early_stop_num=hz * max_sec,
                )
                pred_semantic = pred_semantic[:, -idx:].unsqueeze(0)
                cache[i_text] = pred_semantic
        t3 = ttime()
        audio = (vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refers,speed=speed).detach().cpu().numpy()[0, 0])
        # Simple guard against 16-bit clipping.
        max_audio = np.abs(audio).max()
        if max_audio > 1:
            audio /= max_audio
        audio_opt.append(audio)
        audio_opt.append(zero_wav)
        t4 = ttime()
        t.extend([t2 - t1,t3 - t2, t4 - t3])
        t1 = ttime()
    if not audio_opt:
        # Every sentence was blank; np.concatenate would fail on an empty list.
        raise ValueError(i18n("请输入有效文本"))
    print("%.3f\t%.3f\t%.3f\t%.3f" %
          (t[0], sum(t[1::3]), sum(t[2::3]), sum(t[3::3]))
          )
    # Scale in float32 and clip so a +1.0 peak maps to 32767 instead of
    # wrapping around to -32768 in the int16 cast.
    pcm = np.concatenate(audio_opt, 0).astype(np.float32) * 32768
    yield hps.data.sampling_rate, np.clip(pcm, -32768, 32767).astype(np.int16)
486
+
487
+
488
def split(todo_text):
    """Split text into pieces that each end with a separator from ``splits``.

    ``"……"`` is replaced by ``"。"`` and ``"——"`` by ``","`` first, and a
    ``"。"`` is appended when the text does not already end with a separator,
    so the final piece is always terminated. Empty input yields an empty list.
    """
    todo_text = todo_text.replace("……", "。").replace("——", ",")
    if not todo_text:
        # Previously todo_text[-1] raised IndexError for empty input.
        return []
    if todo_text[-1] not in splits:
        todo_text += "。"
    pieces = []
    start = 0
    for pos, char in enumerate(todo_text):
        if char in splits:
            # Close the current piece right after the separator.
            pieces.append(todo_text[start:pos + 1])
            start = pos + 1
    return pieces
505
+
506
+
507
def cut1(inp):
    """Group the text into lines of four sentences each.

    The last line absorbs whatever remains after the preceding groups of
    four. Texts with at most four sentences are kept as a single line.
    Lines consisting only of characters in ``punctuation`` are dropped.
    Empty input returns an empty string.
    """
    inp = inp.strip("\n")
    if not inp:
        # Nothing to split; avoids indexing into an empty sentence list.
        return ""
    inps = split(inp)
    starts = list(range(0, len(inps), 4))
    if len(starts) > 1:
        # Replace the last start index with None so the final group runs to the end.
        bounds = starts[:-1] + [None]
        opts = ["".join(inps[lo:hi]) for lo, hi in zip(bounds, bounds[1:])]
    else:
        opts = [inp]
    opts = [item for item in opts if not set(item).issubset(punctuation)]
    return "\n".join(opts)
520
+
521
+
522
def cut2(inp):
    """Group sentences into lines of more than 50 characters.

    Sentences are accumulated until the running line exceeds 50 characters.
    A final line shorter than 50 characters is merged into the previous one.
    Lines consisting only of characters in ``punctuation`` are dropped.
    Empty input returns an empty string.
    """
    inp = inp.strip("\n")
    if not inp:
        # Nothing to split; avoids indexing into an empty string in split().
        return ""
    inps = split(inp)
    if len(inps) < 2:
        return inp
    opts = []
    buffer = ""
    for piece in inps:
        buffer += piece
        if len(buffer) > 50:
            opts.append(buffer)
            buffer = ""
    if buffer:
        opts.append(buffer)
    # If the last line is too short, merge it into the previous one.
    if len(opts) > 1 and len(opts[-1]) < 50:
        opts[-2:] = [opts[-2] + opts[-1]]
    opts = [item for item in opts if not set(item).issubset(punctuation)]
    return "\n".join(opts)
545
+
546
+
547
def cut3(inp):
    """Split the text at every Chinese full stop ("。"), one piece per line.

    The full stops themselves are removed; pieces consisting only of
    characters in ``punctuation`` are dropped.
    """
    pieces = inp.strip("\n").strip("。").split("。")
    kept = [piece for piece in pieces if not set(piece).issubset(punctuation)]
    return "\n".join(kept)
552
+
553
def cut4(inp):
    """Split the text at every English period ("."), one piece per line.

    A period between two digits (e.g. "3.14") is not a split point, matching
    the decimal handling in ``cut5``. The periods used as split points are
    removed; pieces consisting only of characters in ``punctuation`` are
    dropped.
    """
    inp = inp.strip("\n")
    # Split on a period unless it has a digit on both sides.
    pieces = re.split(r"(?<!\d)\.|\.(?!\d)", inp.strip("."))
    kept = [piece for piece in pieces if not set(piece).issubset(punctuation)]
    return "\n".join(kept)
558
+
559
+
560
# contributed by https://github.com/AI-Hobbyist/GPT-SoVITS/blob/main/GPT_SoVITS/inference_webui.py
def cut5(inp):
    """Split the text after every punctuation mark, one piece per line.

    A "." between two digits (e.g. "3.14") is kept inside its piece. Pieces
    consisting only of punctuation marks are dropped.
    """
    inp = inp.strip("\n")
    # ASCII and full-width marks. The original literal repeated the ASCII
    # ',', '?', '!' and ';' (a no-op in a set), so full-width punctuation
    # was never treated as a split point.
    punds = {',', '.', ';', '?', '!', ':', '、', ',', '。', '?', '!', ';', ':', '…'}
    pieces = []
    current = []
    last = len(inp) - 1

    for i, char in enumerate(inp):
        current.append(char)
        if char not in punds:
            continue
        is_decimal_point = (
            char == '.' and 0 < i < last
            and inp[i - 1].isdigit() and inp[i + 1].isdigit()
        )
        if not is_decimal_point:
            pieces.append("".join(current))
            current = []

    if current:
        pieces.append("".join(current))

    kept = [piece for piece in pieces if not set(piece).issubset(punds)]
    return "\n".join(kept)
583
+
584
+
585
def custom_sort_key(s):
    """Return a natural-sort key for ``s``.

    The string is split into alternating non-digit and digit runs, and digit
    runs are converted to ``int`` so that e.g. "a2" sorts before "a10".
    """
    # Raw string: '\d' in a normal string literal is an invalid escape sequence.
    parts = re.split(r'(\d+)', s)
    return [int(part) if part.isdigit() else part for part in parts]
591
+
592
def process_text(texts):
    """Drop empty entries from ``texts``.

    Entries equal to ``None``, ``" "`` or ``""`` are removed.

    Raises:
        ValueError: If every entry is ``None``, ``" "``, ``"\\n"`` or ``""``
            (this includes an empty list).
    """
    if all(text in [None, " ", "\n", ""] for text in texts):
        raise ValueError(i18n("请输入有效文本"))
    return [text for text in texts if text not in [None, " ", ""]]
602
+
603
+
604
def html_center(text, label='p'):
    """Wrap ``text`` in a ``<label>`` element inside a centre-aligned ``<div>``."""
    inner = f'<{label} style="margin: 0; padding: 0;">{text}</{label}>'
    return f'<div style="text-align: center; margin: 100; padding: 50;">\n{inner}\n</div>'
608
+
609
def html_left(text, label='p'):
    """Wrap ``text`` in a ``<label>`` element inside a left-aligned ``<div>``."""
    inner = f'<{label} style="margin: 0; padding: 0;">{text}</{label}>'
    return f'<div style="text-align: left; margin: 0; padding: 0;">\n{inner}\n</div>'
613
+
614
+
615
+ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
616
+ gr.Markdown(
617
+ value="""# GPT-SoVITS-v2 Zero-shot TTS demo
618
+ ## https://github.com/RVC-Boss/GPT-SoVITS
619
+ Input 3 to 10s reference audio to guide the timbre, speed, and emotion of the voice, and generate the speech you want by inputting the inference text. <br>
620
+ 输入3至10秒的参考音频来引导待合成语音的音色、语速和情感,然后输入待合成目标文本,生成目标语音. <br>
621
+ Cross-lingual Support: Inference in languages different from the training dataset, currently supporting English, Japanese, Korean and Cantonese.<br>
622
+ 目前支持中日英韩粤跨语种合成。<br>
623
+ This demo is open source under the MIT license. The author does not have any control over it. Users who use the software and distribute the sounds exported by the software are solely responsible. If you do not agree with this clause, you cannot use or reference any codes and files within this demo. <br>
624
+ 本demo以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责. 如不认可该条款, 则不能使用或引用该demo内的任何代码和文件.
625
+ """
626
+ )
627
+ with gr.Group():
628
+ gr.Markdown(html_center(i18n("*请上传并填写参考信息"),'h3'))
629
+ with gr.Row():
630
+ inp_ref = gr.Audio(label=i18n("请上传3~10秒内参考音频,超过会报错!"), type="filepath")
631
+ with gr.Column():
632
+ ref_text_free = gr.Checkbox(label=i18n("开启无参考文本模式。不填参考文本亦相当于开启。"), value=False, interactive=True, show_label=True)
633
+ gr.Markdown(html_left(i18n("使用无参考文本模式时建议使用微调的GPT,听不清参考音频说的啥(不晓得写啥)可以开。<br>开启后无视填写的参考文本。")))
634
+ prompt_text = gr.Textbox(label=i18n("参考音频的文本"), value="", lines=3, max_lines=3)
635
+ prompt_language = gr.Dropdown(
636
+ label=i18n("参考音频的语种"), choices=list(dict_language.keys()), value=i18n("中文")
637
+ )
638
+ inp_refs = gr.File(label=i18n("可选项:通过拖拽多个文件上传多个参考音频(建议同性),平均融合他们的音色。如不填写此项,音色由左侧单个参考音频控制。"),file_count="multiple")
639
+ gr.Markdown(html_center(i18n("*请填写需要合成的目标文本和语种模式"),'h3'))
640
+ with gr.Row():
641
+ with gr.Column():
642
+ text = gr.Textbox(label=i18n("需要合成的文本"), value="", lines=26, max_lines=26)
643
+ with gr.Column():
644
+ text_language = gr.Dropdown(
645
+ label=i18n("需要合成的语种")+i18n(".限制范围越小判别效果越好。"), choices=list(dict_language.keys()), value=i18n("中文")
646
+ )
647
+ how_to_cut = gr.Dropdown(
648
+ label=i18n("怎么切"),
649
+ choices=[i18n("不切"), i18n("凑四句一切"), i18n("凑50字一切"), i18n("按中文句号。切"), i18n("按英文句号.切"), i18n("按标点符号切"), ],
650
+ value=i18n("凑四句一切"),
651
+ interactive=True
652
+ )
653
+ gr.Markdown(value=html_center(i18n("语速调整,高为更快")))
654
+ if_freeze=gr.Checkbox(label=i18n("是否直接对上次合成结果调整语速和音色。防止随机性。"), value=False, interactive=True,show_label=True)
655
+ speed = gr.Slider(minimum=0.6,maximum=1.65,step=0.05,label=i18n("语速"),value=1,interactive=True)
656
+ gr.Markdown(html_center(i18n("GPT采样参数(无参考文本时不要太低。不懂就用默认):")))
657
+ top_k = gr.Slider(minimum=1,maximum=100,step=1,label=i18n("top_k"),value=15,interactive=True)
658
+ top_p = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("top_p"),value=1,interactive=True)
659
+ temperature = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("temperature"),value=1,interactive=True)
660
+ with gr.Row():
661
+ inference_button = gr.Button(i18n("合成语音"), variant="primary", size='lg')
662
+ output = gr.Audio(label=i18n("输出的语音"))
663
+
664
+ inference_button.click(
665
+ get_tts_wav,
666
+ [inp_ref, prompt_text, prompt_language, text, text_language, how_to_cut, top_k, top_p, temperature, ref_text_free,speed,if_freeze,inp_refs],
667
+ [output],
668
+ )
669
+
670
if __name__ == '__main__':
    # Enable request queuing, then serve on all network interfaces.
    queued_app = app.queue()
    queued_app.launch(
        server_name="0.0.0.0",
        inbrowser=True,
        quiet=True,
    )
module/__init__.py ADDED
File without changes
module/attentions.py ADDED
@@ -0,0 +1,709 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ from torch import nn
4
+ from torch.nn import functional as F
5
+
6
+ from module import commons
7
+ from module.modules import LayerNorm
8
+
9
+
10
class Encoder(nn.Module):
    """Stack of self-attention + feed-forward layers with post-layer-norm.

    Each of the ``n_layers`` layers applies windowed relative self-attention
    (``MultiHeadAttention`` with ``window_size``) followed by an ``FFN``, each
    wrapped in dropout, a residual connection and ``LayerNorm``. With
    ``isflow=True`` the encoder additionally takes a conditioning input ``g``
    and requires ``gin_channels`` in ``kwargs``.
    """

    def __init__(
        self,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size=1,
        p_dropout=0.0,
        window_size=4,
        isflow=False,
        **kwargs
    ):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.window_size = window_size

        self.drop = nn.Dropout(p_dropout)
        self.attn_layers = nn.ModuleList()
        self.norm_layers_1 = nn.ModuleList()
        self.ffn_layers = nn.ModuleList()
        self.norm_layers_2 = nn.ModuleList()
        # Sub-modules are created layer by layer, in this exact order, so the
        # parameter initialisation order stays the same.
        for _ in range(n_layers):
            attention = MultiHeadAttention(
                hidden_channels,
                hidden_channels,
                n_heads,
                p_dropout=p_dropout,
                window_size=window_size,
            )
            self.attn_layers.append(attention)
            self.norm_layers_1.append(LayerNorm(hidden_channels))
            feed_forward = FFN(
                hidden_channels,
                hidden_channels,
                filter_channels,
                kernel_size,
                p_dropout=p_dropout,
            )
            self.ffn_layers.append(feed_forward)
            self.norm_layers_2.append(LayerNorm(hidden_channels))
        if isflow:
            gin_channels = kwargs["gin_channels"]
            # One 1x1 conv produces the conditioning slices for all layers at once.
            cond_layer = torch.nn.Conv1d(
                gin_channels, 2 * hidden_channels * n_layers, 1
            )
            self.cond_pre = torch.nn.Conv1d(hidden_channels, 2 * hidden_channels, 1)
            # NOTE(review): weight_norm_modules is not imported or defined in the
            # visible part of this file — confirm it exists before using isflow=True.
            self.cond_layer = weight_norm_modules(cond_layer, name="weight")
            self.gin_channels = gin_channels

    def forward(self, x, x_mask, g=None):
        """Encode ``x`` under ``x_mask``, optionally conditioned on ``g``."""
        attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
        x = x * x_mask
        if g is not None:
            g = self.cond_layer(g)

        slice_width = 2 * self.hidden_channels
        layers = zip(
            self.attn_layers, self.norm_layers_1, self.ffn_layers, self.norm_layers_2
        )
        for i, (attention, norm_1, feed_forward, norm_2) in enumerate(layers):
            if g is not None:
                # Mix in this layer's slice of the conditioning tensor.
                x = self.cond_pre(x)
                g_l = g[:, i * slice_width : i * slice_width + slice_width, :]
                x = commons.fused_add_tanh_sigmoid_multiply(
                    x, g_l, torch.IntTensor([self.hidden_channels])
                )
            x = norm_1(x + self.drop(attention(x, x, attn_mask)))
            x = norm_2(x + self.drop(feed_forward(x, x_mask)))
        return x * x_mask
89
+
90
+
91
class Decoder(nn.Module):
    """Stack of masked self-attention, encoder-decoder attention and FFN layers.

    Each of the ``n_layers`` layers applies self-attention under a
    subsequent-position mask, attention over the encoder output, and a causal
    ``FFN``; every sub-layer is wrapped in dropout, a residual connection and
    ``LayerNorm``.
    """

    def __init__(
        self,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size=1,
        p_dropout=0.0,
        proximal_bias=False,
        proximal_init=True,
        **kwargs
    ):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.proximal_bias = proximal_bias
        self.proximal_init = proximal_init

        self.drop = nn.Dropout(p_dropout)
        self.self_attn_layers = nn.ModuleList()
        self.norm_layers_0 = nn.ModuleList()
        self.encdec_attn_layers = nn.ModuleList()
        self.norm_layers_1 = nn.ModuleList()
        self.ffn_layers = nn.ModuleList()
        self.norm_layers_2 = nn.ModuleList()
        # Sub-modules are created layer by layer, in this exact order, so the
        # parameter initialisation order stays the same.
        for _ in range(n_layers):
            self_attention = MultiHeadAttention(
                hidden_channels,
                hidden_channels,
                n_heads,
                p_dropout=p_dropout,
                proximal_bias=proximal_bias,
                proximal_init=proximal_init,
            )
            self.self_attn_layers.append(self_attention)
            self.norm_layers_0.append(LayerNorm(hidden_channels))
            cross_attention = MultiHeadAttention(
                hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout
            )
            self.encdec_attn_layers.append(cross_attention)
            self.norm_layers_1.append(LayerNorm(hidden_channels))
            feed_forward = FFN(
                hidden_channels,
                hidden_channels,
                filter_channels,
                kernel_size,
                p_dropout=p_dropout,
                causal=True,
            )
            self.ffn_layers.append(feed_forward)
            self.norm_layers_2.append(LayerNorm(hidden_channels))

    def forward(self, x, x_mask, h, h_mask):
        """
        x: decoder input
        h: encoder output
        """
        self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(
            device=x.device, dtype=x.dtype
        )
        encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
        x = x * x_mask
        layers = zip(
            self.self_attn_layers,
            self.norm_layers_0,
            self.encdec_attn_layers,
            self.norm_layers_1,
            self.ffn_layers,
            self.norm_layers_2,
        )
        for self_attention, norm_0, cross_attention, norm_1, feed_forward, norm_2 in layers:
            x = norm_0(x + self.drop(self_attention(x, x, self_attn_mask)))
            x = norm_1(x + self.drop(cross_attention(x, h, encdec_attn_mask)))
            x = norm_2(x + self.drop(feed_forward(x, x_mask)))
        return x * x_mask
175
+
176
+
177
+ class MultiHeadAttention(nn.Module):
178
+ def __init__(
179
+ self,
180
+ channels,
181
+ out_channels,
182
+ n_heads,
183
+ p_dropout=0.0,
184
+ window_size=None,
185
+ heads_share=True,
186
+ block_length=None,
187
+ proximal_bias=False,
188
+ proximal_init=False,
189
+ ):
190
+ super().__init__()
191
+ assert channels % n_heads == 0
192
+
193
+ self.channels = channels
194
+ self.out_channels = out_channels
195
+ self.n_heads = n_heads
196
+ self.p_dropout = p_dropout
197
+ self.window_size = window_size
198
+ self.heads_share = heads_share
199
+ self.block_length = block_length
200
+ self.proximal_bias = proximal_bias
201
+ self.proximal_init = proximal_init
202
+ self.attn = None
203
+
204
+ self.k_channels = channels // n_heads
205
+ self.conv_q = nn.Conv1d(channels, channels, 1)
206
+ self.conv_k = nn.Conv1d(channels, channels, 1)
207
+ self.conv_v = nn.Conv1d(channels, channels, 1)
208
+ self.conv_o = nn.Conv1d(channels, out_channels, 1)
209
+ self.drop = nn.Dropout(p_dropout)
210
+
211
+ if window_size is not None:
212
+ n_heads_rel = 1 if heads_share else n_heads
213
+ rel_stddev = self.k_channels**-0.5
214
+ self.emb_rel_k = nn.Parameter(
215
+ torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
216
+ * rel_stddev
217
+ )
218
+ self.emb_rel_v = nn.Parameter(
219
+ torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
220
+ * rel_stddev
221
+ )
222
+
223
+ nn.init.xavier_uniform_(self.conv_q.weight)
224
+ nn.init.xavier_uniform_(self.conv_k.weight)
225
+ nn.init.xavier_uniform_(self.conv_v.weight)
226
+ if proximal_init:
227
+ with torch.no_grad():
228
+ self.conv_k.weight.copy_(self.conv_q.weight)
229
+ self.conv_k.bias.copy_(self.conv_q.bias)
230
+
231
+ def forward(self, x, c, attn_mask=None):
232
+ q = self.conv_q(x)
233
+ k = self.conv_k(c)
234
+ v = self.conv_v(c)
235
+
236
+ x, self.attn = self.attention(q, k, v, mask=attn_mask)
237
+
238
+ x = self.conv_o(x)
239
+ return x
240
+
241
+ def attention(self, query, key, value, mask=None):
242
+ # reshape [b, d, t] -> [b, n_h, t, d_k]
243
+ b, d, t_s, t_t = (*key.size(), query.size(2))
244
+ query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
245
+ key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
246
+ value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
247
+
248
+ scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
249
+ if self.window_size is not None:
250
+ assert (
251
+ t_s == t_t
252
+ ), "Relative attention is only available for self-attention."
253
+ key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
254
+ rel_logits = self._matmul_with_relative_keys(
255
+ query / math.sqrt(self.k_channels), key_relative_embeddings
256
+ )
257
+ scores_local = self._relative_position_to_absolute_position(rel_logits)
258
+ scores = scores + scores_local
259
+ if self.proximal_bias:
260
+ assert t_s == t_t, "Proximal bias is only available for self-attention."
261
+ scores = scores + self._attention_bias_proximal(t_s).to(
262
+ device=scores.device, dtype=scores.dtype
263
+ )
264
+ if mask is not None:
265
+ scores = scores.masked_fill(mask == 0, -1e4)
266
+ if self.block_length is not None:
267
+ assert (
268
+ t_s == t_t
269
+ ), "Local attention is only available for self-attention."
270
+ block_mask = (
271
+ torch.ones_like(scores)
272
+ .triu(-self.block_length)
273
+ .tril(self.block_length)
274
+ )
275
+ scores = scores.masked_fill(block_mask == 0, -1e4)
276
+ p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s]
277
+ p_attn = self.drop(p_attn)
278
+ output = torch.matmul(p_attn, value)
279
+ if self.window_size is not None:
280
+ relative_weights = self._absolute_position_to_relative_position(p_attn)
281
+ value_relative_embeddings = self._get_relative_embeddings(
282
+ self.emb_rel_v, t_s
283
+ )
284
+ output = output + self._matmul_with_relative_values(
285
+ relative_weights, value_relative_embeddings
286
+ )
287
+ output = (
288
+ output.transpose(2, 3).contiguous().view(b, d, t_t)
289
+ ) # [b, n_h, t_t, d_k] -> [b, d, t_t]
290
+ return output, p_attn
291
+
292
+ def _matmul_with_relative_values(self, x, y):
293
+ """
294
+ x: [b, h, l, m]
295
+ y: [h or 1, m, d]
296
+ ret: [b, h, l, d]
297
+ """
298
+ ret = torch.matmul(x, y.unsqueeze(0))
299
+ return ret
300
+
301
+ def _matmul_with_relative_keys(self, x, y):
302
+ """
303
+ x: [b, h, l, d]
304
+ y: [h or 1, m, d]
305
+ ret: [b, h, l, m]
306
+ """
307
+ ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
308
+ return ret
309
+
310
+ def _get_relative_embeddings(self, relative_embeddings, length):
311
+ max_relative_position = 2 * self.window_size + 1
312
+ # Pad first before slice to avoid using cond ops.
313
+ pad_length = max(length - (self.window_size + 1), 0)
314
+ slice_start_position = max((self.window_size + 1) - length, 0)
315
+ slice_end_position = slice_start_position + 2 * length - 1
316
+ if pad_length > 0:
317
+ padded_relative_embeddings = F.pad(
318
+ relative_embeddings,
319
+ commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
320
+ )
321
+ else:
322
+ padded_relative_embeddings = relative_embeddings
323
+ used_relative_embeddings = padded_relative_embeddings[
324
+ :, slice_start_position:slice_end_position
325
+ ]
326
+ return used_relative_embeddings
327
+
328
def _relative_position_to_absolute_position(self, x):
    """Re-index relative-position logits to absolute positions.

    Args:
        x: ``[b, h, l, 2*l-1]``.

    Returns:
        Tensor of shape ``[b, h, l, l]``.
    """
    batch, heads, length, _ = x.size()

    # Append one zero column so each row has 2*l entries; flattening then
    # shifts every successive row by one position.
    column_pad = commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]])
    widened = F.pad(x, column_pad)

    # Flatten and append l-1 zeros so the buffer reshapes to (l+1, 2*l-1).
    flat = widened.view([batch, heads, length * 2 * length])
    tail_pad = commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
    flat = F.pad(flat, tail_pad)

    # Reshape and keep only the rows/columns that hold real logits.
    skewed = flat.view([batch, heads, length + 1, 2 * length - 1])
    return skewed[:, :, :length, length - 1 :]
348
+
349
def _absolute_position_to_relative_position(self, x):
    """Re-index absolute-position weights to relative positions.

    Args:
        x: ``[b, h, l, l]``.

    Returns:
        Tensor of shape ``[b, h, l, 2*l-1]``.
    """
    batch, heads, length, _ = x.size()

    # Pad l-1 zero columns on the right: each row becomes 2*l-1 wide.
    column_pad = commons.convert_pad_shape(
        [[0, 0], [0, 0], [0, 0], [0, length - 1]]
    )
    widened = F.pad(x, column_pad)
    flat = widened.view([batch, heads, length**2 + length * (length - 1)])

    # Prepend l zeros so that the reshape below skews the rows.
    head_pad = commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]])
    flat = F.pad(flat, head_pad)

    skewed = flat.view([batch, heads, length, 2 * length])
    return skewed[:, :, :, 1:]
364
+
365
+ def _attention_bias_proximal(self, length):
366
+ """Bias for self-attention to encourage attention to close positions.
367
+ Args:
368
+ length: an integer scalar.
369
+ Returns:
370
+ a Tensor with shape [1, 1, length, length]
371
+ """
372
+ r = torch.arange(length, dtype=torch.float32)
373
+ diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
374
+ return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
375
+
376
+
377
class FFN(nn.Module):
    """Masked feed-forward block made of two 1-D convolutions.

    The input is masked, padded, convolved to ``filter_channels``, passed
    through an activation and dropout, then masked, padded and convolved to
    ``out_channels``. The result is masked once more before being returned.

    Args:
        in_channels: channels of the input.
        out_channels: channels of the output.
        filter_channels: channels between the two convolutions.
        kernel_size: kernel size of both convolutions.
        p_dropout: dropout probability applied after the activation.
        activation: ``"gelu"`` selects ``x * sigmoid(1.702 * x)``; any other
            value selects ReLU.
        causal: pad only on the left when true, otherwise pad both sides.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        filter_channels,
        kernel_size,
        p_dropout=0.0,
        activation=None,
        causal=False,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.activation = activation
        self.causal = causal

        # Choose the padding strategy once instead of branching per call.
        self.padding = self._causal_padding if causal else self._same_padding

        self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
        self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
        self.drop = nn.Dropout(p_dropout)

    def forward(self, x, x_mask):
        """Apply the block to ``x``, multiplying by ``x_mask`` around each conv."""
        hidden = self.conv_1(self.padding(x * x_mask))
        if self.activation == "gelu":
            # Sigmoid-based approximation of GELU.
            hidden = hidden * torch.sigmoid(1.702 * hidden)
        else:
            hidden = torch.relu(hidden)
        hidden = self.drop(hidden)
        out = self.conv_2(self.padding(hidden * x_mask))
        return out * x_mask

    def _pad_last_axis(self, x, pad_l, pad_r):
        """Zero-pad the last axis of a 3-D tensor by ``pad_l`` / ``pad_r``."""
        spec = commons.convert_pad_shape([[0, 0], [0, 0], [pad_l, pad_r]])
        return F.pad(x, spec)

    def _causal_padding(self, x):
        """Pad ``kernel_size - 1`` zeros on the left only (no-op for size 1)."""
        if self.kernel_size == 1:
            return x
        return self._pad_last_axis(x, self.kernel_size - 1, 0)

    def _same_padding(self, x):
        """Pad both sides so the convolution keeps the length (no-op for size 1)."""
        if self.kernel_size == 1:
            return x
        return self._pad_last_axis(
            x, (self.kernel_size - 1) // 2, self.kernel_size // 2
        )
433
+
434
+
435
+ import torch.nn as nn
436
+ from torch.nn.utils import remove_weight_norm, weight_norm
437
+
438
+
439
class Depthwise_Separable_Conv1D(nn.Module):
    """Depthwise 1-D convolution followed by a 1x1 pointwise convolution.

    The depthwise stage uses ``groups=in_channels`` and keeps the channel
    count; the pointwise stage maps ``in_channels`` to ``out_channels``.
    ``stride``, ``padding``, ``dilation`` and ``padding_mode`` apply to the
    depthwise stage only; ``bias``, ``device`` and ``dtype`` apply to both.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        bias=True,
        padding_mode="zeros",  # TODO: refine this type
        device=None,
        dtype=None,
    ):
        super().__init__()
        shared = {"bias": bias, "device": device, "dtype": dtype}
        self.depth_conv = nn.Conv1d(
            in_channels,
            in_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=in_channels,
            padding_mode=padding_mode,
            **shared,
        )
        self.point_conv = nn.Conv1d(in_channels, out_channels, 1, **shared)

    def forward(self, input):
        """Run the depthwise stage, then the pointwise stage."""
        depthwise_out = self.depth_conv(input)
        return self.point_conv(depthwise_out)

    def weight_norm(self):
        """Apply weight normalization to the ``weight`` of both convolutions."""
        for attr in ("depth_conv", "point_conv"):
            setattr(self, attr, weight_norm(getattr(self, attr), name="weight"))

    def remove_weight_norm(self):
        """Remove weight normalization from both convolutions."""
        for attr in ("depth_conv", "point_conv"):
            setattr(
                self, attr, remove_weight_norm(getattr(self, attr), name="weight")
            )
486
+
487
+
488
class Depthwise_Separable_TransposeConv1D(nn.Module):
    """Depthwise transposed 1-D convolution followed by a 1x1 pointwise conv.

    The transposed stage uses ``groups=in_channels`` and keeps the channel
    count; the pointwise stage maps ``in_channels`` to ``out_channels``.
    ``stride``, ``padding``, ``output_padding``, ``dilation`` and
    ``padding_mode`` apply to the transposed stage only.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        output_padding=0,
        bias=True,
        dilation=1,
        padding_mode="zeros",  # TODO: refine this type
        device=None,
        dtype=None,
    ):
        super().__init__()
        shared = {"bias": bias, "device": device, "dtype": dtype}
        self.depth_conv = nn.ConvTranspose1d(
            in_channels,
            in_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
            groups=in_channels,
            dilation=dilation,
            padding_mode=padding_mode,
            **shared,
        )
        self.point_conv = nn.Conv1d(in_channels, out_channels, 1, **shared)

    def forward(self, input):
        """Run the transposed depthwise stage, then the pointwise stage."""
        upsampled = self.depth_conv(input)
        return self.point_conv(upsampled)

    def weight_norm(self):
        """Apply weight normalization to the ``weight`` of both convolutions."""
        for attr in ("depth_conv", "point_conv"):
            setattr(self, attr, weight_norm(getattr(self, attr), name="weight"))

    def remove_weight_norm(self):
        """Remove weight normalization from both convolutions in place."""
        for layer in (self.depth_conv, self.point_conv):
            remove_weight_norm(layer, name="weight")
537
+
538
+
539
def weight_norm_modules(module, name="weight", dim=0):
    """Apply weight normalization, dispatching on the module type.

    Depthwise-separable wrappers normalize their own inner convolutions via
    their ``weight_norm()`` method (``name`` and ``dim`` are not forwarded in
    that case); any other module goes through ``weight_norm`` directly.

    Returns:
        The module with weight normalization applied.
    """
    separable_types = (
        Depthwise_Separable_Conv1D,
        Depthwise_Separable_TransposeConv1D,
    )
    if not isinstance(module, separable_types):
        return weight_norm(module, name, dim)
    module.weight_norm()
    return module
547
+
548
+
549
def remove_weight_norm_modules(module, name="weight"):
    """Remove weight normalization, dispatching on the module type.

    Depthwise-separable wrappers strip their inner convolutions via their
    own ``remove_weight_norm()`` method (``name`` is not forwarded in that
    case); any other module goes through ``remove_weight_norm`` directly.
    Nothing is returned.
    """
    separable_types = (
        Depthwise_Separable_Conv1D,
        Depthwise_Separable_TransposeConv1D,
    )
    if not isinstance(module, separable_types):
        remove_weight_norm(module, name)
        return
    module.remove_weight_norm()
556
+
557
+
558
class FFT(nn.Module):
    """Stack of causal self-attention + causal FFN layers.

    Each layer applies masked self-attention (lower-triangular mask), a
    residual LayerNorm, a causal ``FFN`` and a second residual LayerNorm.
    When built with ``isflow=True`` the stack also owns a conditioning path:
    ``cond_layer`` projects ``g`` to ``2 * hidden_channels`` channels per
    layer and ``cond_pre`` widens ``x`` so the two can be gated together.

    Args:
        hidden_channels: channel width of the stack.
        filter_channels: inner width of every ``FFN``.
        n_heads: attention heads per layer.
        n_layers: number of layers.
        kernel_size: kernel size of every ``FFN``.
        p_dropout: dropout probability.
        proximal_bias: forwarded to ``MultiHeadAttention``.
        proximal_init: forwarded to ``MultiHeadAttention``.
        isflow: build the conditioning layers; requires ``gin_channels`` in
            ``kwargs``.
    """

    def __init__(
        self,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers=1,
        kernel_size=1,
        p_dropout=0.0,
        proximal_bias=False,
        proximal_init=True,
        isflow=False,
        **kwargs
    ):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.proximal_bias = proximal_bias
        self.proximal_init = proximal_init

        if isflow:
            gin_channels = kwargs["gin_channels"]
            raw_cond_layer = torch.nn.Conv1d(
                gin_channels, 2 * hidden_channels * n_layers, 1
            )
            self.cond_pre = torch.nn.Conv1d(hidden_channels, 2 * hidden_channels, 1)
            self.cond_layer = weight_norm_modules(raw_cond_layer, name="weight")
            self.gin_channels = gin_channels

        self.drop = nn.Dropout(p_dropout)
        self.self_attn_layers = nn.ModuleList()
        self.norm_layers_0 = nn.ModuleList()
        self.ffn_layers = nn.ModuleList()
        self.norm_layers_1 = nn.ModuleList()
        for _ in range(self.n_layers):
            attention = MultiHeadAttention(
                hidden_channels,
                hidden_channels,
                n_heads,
                p_dropout=p_dropout,
                proximal_bias=proximal_bias,
                proximal_init=proximal_init,
            )
            self.self_attn_layers.append(attention)
            self.norm_layers_0.append(LayerNorm(hidden_channels))
            feed_forward = FFN(
                hidden_channels,
                hidden_channels,
                filter_channels,
                kernel_size,
                p_dropout=p_dropout,
                causal=True,
            )
            self.ffn_layers.append(feed_forward)
            self.norm_layers_1.append(LayerNorm(hidden_channels))

    def forward(self, x, x_mask, g=None):
        """Run the stack on ``x``.

        Args:
            x: input sequence; multiplied by ``x_mask`` on entry and exit.
            x_mask: mask whose size along axis 2 sets the causal mask length.
            g: optional conditioning input; when given it is projected by
                ``cond_layer`` and a per-layer slice is gated into ``x``.
        """
        if g is not None:
            g = self.cond_layer(g)

        causal_mask = commons.subsequent_mask(x_mask.size(2)).to(
            device=x.device, dtype=x.dtype
        )
        x = x * x_mask

        slice_width = 2 * self.hidden_channels
        stages = zip(
            self.self_attn_layers,
            self.norm_layers_0,
            self.ffn_layers,
            self.norm_layers_1,
        )
        for i, (attention, norm_0, feed_forward, norm_1) in enumerate(stages):
            if g is not None:
                x = self.cond_pre(x)
                cond_offset = i * slice_width
                g_l = g[:, cond_offset : cond_offset + slice_width, :]
                x = commons.fused_add_tanh_sigmoid_multiply(
                    x, g_l, torch.IntTensor([self.hidden_channels])
                )
            attended = self.drop(attention(x, x, causal_mask))
            x = norm_0(x + attended)

            transformed = self.drop(feed_forward(x, x_mask))
            x = norm_1(x + transformed)
        return x * x_mask
646
+
647
+
648
class TransformerCouplingLayer(nn.Module):
    """Affine coupling layer whose statistics come from a transformer encoder.

    The input is split in two along the channel axis. The first half drives
    ``pre -> enc -> post`` to produce a shift ``m`` (and a log-scale ``logs``
    unless ``mean_only``), which transforms the second half. ``post`` starts
    at zero, so a freshly built layer is the identity.

    Args:
        channels: total channels of the input; must be even.
        hidden_channels: width of the internal encoder.
        kernel_size, n_layers, n_heads, p_dropout, filter_channels: forwarded
            to ``Encoder`` when one is built here.
        mean_only: predict only the shift and use a zero log-scale.
        wn_sharing_parameter: an already-built encoder to use instead of
            constructing a new ``Encoder``.
        gin_channels: forwarded to ``Encoder`` when one is built here.
    """

    def __init__(
        self,
        channels,
        hidden_channels,
        kernel_size,
        n_layers,
        n_heads,
        p_dropout=0,
        filter_channels=0,
        mean_only=False,
        wn_sharing_parameter=None,
        gin_channels=0,
    ):
        assert channels % 2 == 0, "channels should be divisible by 2"
        super().__init__()
        self.channels = channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.half_channels = channels // 2
        self.mean_only = mean_only

        self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
        if wn_sharing_parameter is None:
            self.enc = Encoder(
                hidden_channels,
                filter_channels,
                n_heads,
                n_layers,
                kernel_size,
                p_dropout,
                isflow=True,
                gin_channels=gin_channels,
            )
        else:
            self.enc = wn_sharing_parameter

        # One set of channels for the shift, plus one for the log-scale
        # unless only the mean is predicted.
        stats_channels = self.half_channels * (2 - mean_only)
        self.post = nn.Conv1d(hidden_channels, stats_channels, 1)
        self.post.weight.data.zero_()
        self.post.bias.data.zero_()

    def forward(self, x, x_mask, g=None, reverse=False):
        """Apply the coupling transform.

        Returns:
            ``(x, logdet)`` in the forward direction, or just ``x`` when
            ``reverse`` is true.
        """
        halves = [self.half_channels] * 2
        x0, x1 = torch.split(x, halves, 1)

        hidden = self.enc(self.pre(x0) * x_mask, x_mask, g=g)
        stats = self.post(hidden) * x_mask
        if self.mean_only:
            m = stats
            logs = torch.zeros_like(m)
        else:
            m, logs = torch.split(stats, halves, 1)

        if reverse:
            x1 = (x1 - m) * torch.exp(-logs) * x_mask
            return torch.cat([x0, x1], 1)

        x1 = m + x1 * torch.exp(logs) * x_mask
        logdet = torch.sum(logs, [1, 2])
        return torch.cat([x0, x1], 1), logdet
module/attentions_onnx.py ADDED
@@ -0,0 +1,354 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ from torch import nn
4
+ from torch.nn import functional as F
5
+
6
+ from module import commons
7
+ from module.modules import LayerNorm
8
+
9
+
10
class LayerNorm(nn.Module):
    """Layer normalization over axis 1 of the input.

    ``F.layer_norm`` normalizes the last axis, so the input is transposed to
    bring axis 1 last, normalized with the learnable ``gamma`` / ``beta``,
    and transposed back.

    Args:
        channels: size of axis 1 of the input.
        eps: value added to the variance for numerical stability.
    """

    def __init__(self, channels, eps=1e-5):
        super().__init__()
        self.channels = channels
        self.eps = eps

        self.gamma = nn.Parameter(torch.ones(channels))
        self.beta = nn.Parameter(torch.zeros(channels))

    def forward(self, x):
        """Normalize ``x`` over axis 1 and return a tensor of the same shape."""
        channels_last = x.transpose(1, -1)
        normalized = F.layer_norm(
            channels_last, (self.channels,), self.gamma, self.beta, self.eps
        )
        return normalized.transpose(1, -1)
23
+
24
+
25
@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
    """Gated activation: ``tanh`` of the first channels times ``sigmoid`` of the rest.

    The two inputs are summed, the sum is split along axis 1 at
    ``n_channels[0]``, and the two parts are combined as
    ``tanh(first) * sigmoid(second)``.
    """
    split_at = n_channels[0]
    summed = input_a + input_b
    tanh_part = torch.tanh(summed[:, :split_at, :])
    gate_part = torch.sigmoid(summed[:, split_at:, :])
    return tanh_part * gate_part
33
+
34
+
35
class Encoder(nn.Module):
    """Stack of windowed self-attention + FFN layers with optional conditioning.

    Each layer applies masked self-attention, a residual LayerNorm, an
    ``FFN`` and a second residual LayerNorm. When ``gin_channels`` is passed
    and non-zero, a linear projection of the conditioning input ``g`` is
    added to ``x`` just before layer ``cond_layer_idx``.

    Args:
        hidden_channels: channel width of the stack.
        filter_channels: inner width of every ``FFN``.
        n_heads: attention heads per layer.
        n_layers: number of layers.
        kernel_size: kernel size of every ``FFN``.
        p_dropout: dropout probability.
        window_size: relative-attention window passed to ``MultiHeadAttention``.
        isflow: accepted for signature compatibility; not used here.
        **kwargs: ``gin_channels`` (conditioning width) and optionally
            ``cond_layer_idx`` (layer index that receives the conditioning,
            default 2).

    Raises:
        AssertionError: if conditioning is enabled and ``cond_layer_idx`` is
            not less than ``n_layers``.
    """

    def __init__(
        self,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size=1,
        p_dropout=0.0,
        window_size=4,
        isflow=True,
        **kwargs
    ):
        # Imported locally: this module's import block does not provide
        # ``logging``, which previously made the debug call below raise
        # NameError whenever conditioning was enabled.
        import logging

        super().__init__()
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.window_size = window_size

        # Default of ``n_layers`` means "never": the loop index in forward()
        # cannot reach it.
        self.cond_layer_idx = self.n_layers
        if "gin_channels" in kwargs:
            self.gin_channels = kwargs["gin_channels"]
            if self.gin_channels != 0:
                self.spk_emb_linear = nn.Linear(self.gin_channels, self.hidden_channels)
                # vits2 says 3rd block, so idx is 2 by default
                self.cond_layer_idx = kwargs.get("cond_layer_idx", 2)
                logging.debug(
                    "gin_channels=%s, cond_layer_idx=%s",
                    self.gin_channels,
                    self.cond_layer_idx,
                )
                assert (
                    self.cond_layer_idx < self.n_layers
                ), "cond_layer_idx should be less than n_layers"

        self.drop = nn.Dropout(p_dropout)
        self.attn_layers = nn.ModuleList()
        self.norm_layers_1 = nn.ModuleList()
        self.ffn_layers = nn.ModuleList()
        self.norm_layers_2 = nn.ModuleList()
        for _ in range(self.n_layers):
            self.attn_layers.append(
                MultiHeadAttention(
                    hidden_channels,
                    hidden_channels,
                    n_heads,
                    p_dropout=p_dropout,
                    window_size=window_size,
                )
            )
            self.norm_layers_1.append(LayerNorm(hidden_channels))
            self.ffn_layers.append(
                FFN(
                    hidden_channels,
                    hidden_channels,
                    filter_channels,
                    kernel_size,
                    p_dropout=p_dropout,
                )
            )
            self.norm_layers_2.append(LayerNorm(hidden_channels))

    def forward(self, x, x_mask, g=None):
        """Run the stack on ``x``.

        Args:
            x: input sequence; multiplied by ``x_mask`` on entry and exit.
            x_mask: mask; its outer product with itself is the attention mask.
            g: optional conditioning input, projected by ``spk_emb_linear``
                (applied on its transposed axes 1 and 2) and added to ``x``
                at layer ``cond_layer_idx``.
        """
        attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
        x = x * x_mask
        for i in range(self.n_layers):
            if i == self.cond_layer_idx and g is not None:
                g = self.spk_emb_linear(g.transpose(1, 2))
                g = g.transpose(1, 2)
                x = x + g
                x = x * x_mask
            y = self.attn_layers[i](x, x, attn_mask)
            y = self.drop(y)
            x = self.norm_layers_1[i](x + y)

            y = self.ffn_layers[i](x, x_mask)
            y = self.drop(y)
            x = self.norm_layers_2[i](x + y)
        x = x * x_mask
        return x
119
+
120
+
121
class MultiHeadAttention(nn.Module):
    """Multi-head attention over ``[b, d, t]`` inputs with optional relative positions.

    Queries, keys and values are produced by 1x1 convolutions. When
    ``window_size`` is given, learned relative-position embeddings for keys
    and values are added to the scores and to the output respectively.
    ``block_length`` and ``proximal_bias`` are stored but not consulted by
    ``attention`` in this class.

    Args:
        channels: input width; must be divisible by ``n_heads``.
        out_channels: output width of the final projection.
        n_heads: number of attention heads.
        p_dropout: dropout probability on the attention weights.
        window_size: relative-position window, or ``None`` to disable.
        heads_share: share one relative embedding table across heads.
        block_length: stored only.
        proximal_bias: stored only.
        proximal_init: copy the query projection into the key projection.
    """

    def __init__(
        self,
        channels,
        out_channels,
        n_heads,
        p_dropout=0.0,
        window_size=None,
        heads_share=True,
        block_length=None,
        proximal_bias=False,
        proximal_init=False,
    ):
        super().__init__()
        assert channels % n_heads == 0

        self.channels = channels
        self.out_channels = out_channels
        self.n_heads = n_heads
        self.p_dropout = p_dropout
        self.window_size = window_size
        self.heads_share = heads_share
        self.block_length = block_length
        self.proximal_bias = proximal_bias
        self.proximal_init = proximal_init
        # Attention weights of the most recent forward() call.
        self.attn = None

        self.k_channels = channels // n_heads
        self.conv_q = nn.Conv1d(channels, channels, 1)
        self.conv_k = nn.Conv1d(channels, channels, 1)
        self.conv_v = nn.Conv1d(channels, channels, 1)
        self.conv_o = nn.Conv1d(channels, out_channels, 1)
        self.drop = nn.Dropout(p_dropout)

        if window_size is not None:
            n_heads_rel = 1 if heads_share else n_heads
            rel_stddev = self.k_channels**-0.5
            rel_shape = (n_heads_rel, window_size * 2 + 1, self.k_channels)
            self.emb_rel_k = nn.Parameter(torch.randn(*rel_shape) * rel_stddev)
            self.emb_rel_v = nn.Parameter(torch.randn(*rel_shape) * rel_stddev)

        for projection in (self.conv_q, self.conv_k, self.conv_v):
            nn.init.xavier_uniform_(projection.weight)
        if proximal_init:
            with torch.no_grad():
                self.conv_k.weight.copy_(self.conv_q.weight)
                self.conv_k.bias.copy_(self.conv_q.bias)

    def forward(self, x, c, attn_mask=None):
        """Attend from ``x`` (queries) to ``c`` (keys/values) and project the result."""
        query = self.conv_q(x)
        key = self.conv_k(c)
        value = self.conv_v(c)

        attended, self.attn = self.attention(query, key, value, mask=attn_mask)
        return self.conv_o(attended)

    def attention(self, query, key, value, mask=None):
        """Scaled dot-product attention; returns ``(output, attention weights)``."""
        # reshape [b, d, t] -> [b, n_h, t, d_k]
        b, d, t_s = key.size()
        head_shape = (b, self.n_heads, self.k_channels, -1)
        query = query.view(*head_shape).transpose(2, 3)
        key = key.view(*head_shape).transpose(2, 3)
        value = value.view(*head_shape).transpose(2, 3)

        scaled_query = query / math.sqrt(self.k_channels)
        scores = torch.matmul(scaled_query, key.transpose(-2, -1))

        if self.window_size is not None:
            key_relative_embeddings = self._get_relative_embeddings(
                self.emb_rel_k, t_s
            )
            rel_logits = self._matmul_with_relative_keys(
                scaled_query, key_relative_embeddings
            )
            scores = scores + self._relative_position_to_absolute_position(rel_logits)

        if mask is not None:
            # Large negative instead of -inf so softmax stays finite.
            scores = scores.masked_fill(mask == 0, -1e4)

        p_attn = self.drop(F.softmax(scores, dim=-1))
        output = torch.matmul(p_attn, value)

        if self.window_size is not None:
            relative_weights = self._absolute_position_to_relative_position(p_attn)
            value_relative_embeddings = self._get_relative_embeddings(
                self.emb_rel_v, t_s
            )
            output = output + self._matmul_with_relative_values(
                relative_weights, value_relative_embeddings
            )

        # [b, n_h, t, d_k] -> [b, d, t]
        output = output.transpose(2, 3).contiguous().view(b, d, -1)
        return output, p_attn

    def _matmul_with_relative_values(self, x, y):
        """
        x: [b, h, l, m]
        y: [h or 1, m, d]
        ret: [b, h, l, d]
        """
        return torch.matmul(x, y.unsqueeze(0))

    def _matmul_with_relative_keys(self, x, y):
        """
        x: [b, h, l, d]
        y: [h or 1, m, d]
        ret: [b, h, l, m]
        """
        return torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))

    def _get_relative_embeddings(self, relative_embeddings, length):
        """Pad and slice the embedding table to the ``2 * length - 1`` entries in use.

        Padding and slice bounds are computed as int64 tensors with
        ``torch.max`` rather than Python ``max`` / ``if``, so the table is
        always padded (possibly by zero) and then sliced.
        """
        span = self.window_size + 1
        # Pad first before slice to avoid using cond ops.
        zero = torch.zeros((1), dtype=torch.int64)
        pad_candidate = torch.zeros((1), dtype=torch.int64) + length - span
        start_candidate = torch.zeros((1), dtype=torch.int64) + span - length
        pad_length = torch.max(pad_candidate, other=zero)
        slice_start_position = torch.max(
            start_candidate, other=torch.zeros((1), dtype=torch.int64)
        )

        slice_end_position = slice_start_position + 2 * length - 1
        padded_relative_embeddings = F.pad(
            relative_embeddings,
            commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
        )
        return padded_relative_embeddings[:, slice_start_position:slice_end_position]

    def _relative_position_to_absolute_position(self, x):
        """
        x: [b, h, l, 2*l-1]
        ret: [b, h, l, l]
        """
        batch, heads, length, _ = x.size()
        # Concat columns of pad to shift from relative to absolute indexing.
        widened = F.pad(
            x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]])
        )

        # Concat extra elements so to add up to shape (len+1, 2*len-1).
        flat = widened.view([batch, heads, length * 2 * length])
        flat = F.pad(
            flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
        )

        # Reshape and slice out the padded elements.
        skewed = flat.view([batch, heads, length + 1, 2 * length - 1])
        return skewed[:, :, :length, length - 1 :]

    def _absolute_position_to_relative_position(self, x):
        """
        x: [b, h, l, l]
        ret: [b, h, l, 2*l-1]
        """
        batch, heads, length, _ = x.size()
        # pad along column
        widened = F.pad(
            x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
        )
        flat = widened.view([batch, heads, length**2 + length * (length - 1)])
        # add 0's in the beginning that will skew the elements after reshape
        flat = F.pad(flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
        skewed = flat.view([batch, heads, length, 2 * length])
        return skewed[:, :, :, 1:]

    def _attention_bias_proximal(self, length):
        """Bias for self-attention to encourage attention to close positions.
        Args:
          length: an integer scalar.
        Returns:
          a Tensor with shape [1, 1, length, length]
        """
        positions = torch.arange(length, dtype=torch.float32)
        offsets = torch.unsqueeze(positions, 0) - torch.unsqueeze(positions, 1)
        bias = -torch.log1p(torch.abs(offsets))
        return torch.unsqueeze(torch.unsqueeze(bias, 0), 0)
297
+
298
+
299
class FFN(nn.Module):
    """Masked feed-forward block made of two 1-D convolutions.

    The input is masked, padded, convolved to ``filter_channels``, passed
    through an activation and dropout, then masked, padded and convolved to
    ``out_channels``. The result is masked once more before being returned.

    Args:
        in_channels: channels of the input.
        out_channels: channels of the output.
        filter_channels: channels between the two convolutions.
        kernel_size: kernel size of both convolutions.
        p_dropout: dropout probability applied after the activation.
        activation: ``"gelu"`` selects ``x * sigmoid(1.702 * x)``; any other
            value selects ReLU.
        causal: pad only on the left when true, otherwise pad both sides.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        filter_channels,
        kernel_size,
        p_dropout=0.0,
        activation=None,
        causal=False,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.activation = activation
        self.causal = causal

        # Choose the padding strategy once instead of branching per call.
        self.padding = self._causal_padding if causal else self._same_padding

        self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
        self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
        self.drop = nn.Dropout(p_dropout)

    def forward(self, x, x_mask):
        """Apply the block to ``x``, multiplying by ``x_mask`` around each conv."""
        hidden = self.conv_1(self.padding(x * x_mask))
        if self.activation == "gelu":
            # Sigmoid-based approximation of GELU.
            hidden = hidden * torch.sigmoid(1.702 * hidden)
        else:
            hidden = torch.relu(hidden)
        hidden = self.drop(hidden)
        out = self.conv_2(self.padding(hidden * x_mask))
        return out * x_mask

    def _pad_last_axis(self, x, pad_l, pad_r):
        """Zero-pad the last axis of a 3-D tensor by ``pad_l`` / ``pad_r``."""
        spec = commons.convert_pad_shape([[0, 0], [0, 0], [pad_l, pad_r]])
        return F.pad(x, spec)

    def _causal_padding(self, x):
        """Pad ``kernel_size - 1`` zeros on the left only (no-op for size 1)."""
        if self.kernel_size == 1:
            return x
        return self._pad_last_axis(x, self.kernel_size - 1, 0)

    def _same_padding(self, x):
        """Pad both sides so the convolution keeps the length (no-op for size 1)."""
        if self.kernel_size == 1:
            return x
        return self._pad_last_axis(
            x, (self.kernel_size - 1) // 2, self.kernel_size // 2
        )
module/commons.py ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ from torch.nn import functional as F
4
+
5
+
6
def init_weights(m, mean=0.0, std=0.01):
    """Re-initialize ``m.weight`` from a normal distribution for conv modules.

    A module is treated as a convolution when its class name contains
    ``"Conv"``; every other module is left untouched.
    """
    if "Conv" in m.__class__.__name__:
        m.weight.data.normal_(mean, std)
10
+
11
+
12
def get_padding(kernel_size, dilation=1):
    """Return the per-side padding ``dilation * (kernel_size - 1) / 2``, truncated to int."""
    total_reach = dilation * (kernel_size - 1)
    return int(total_reach / 2)
14
+
15
+
16
def convert_pad_shape(pad_shape):
    """Flatten per-axis ``[before, after]`` pairs into a last-axis-first list.

    The pairs are given first axis first; the result lists them in reverse
    axis order, which is the layout the ``F.pad`` calls in this code base
    are given.
    """
    return [amount for pair in reversed(pad_shape) for amount in pair]
20
+
21
+
22
def intersperse(lst, item):
    """Return a list with ``item`` before, between and after the elements of ``lst``.

    The result has ``2 * len(lst) + 1`` entries.
    """
    result = [item]
    for element in lst:
        result.append(element)
        result.append(item)
    return result
26
+
27
+
28
def kl_divergence(m_p, logs_p, m_q, logs_q):
    """Element-wise KL(P||Q) between Gaussians given means and log standard deviations."""
    # 0.5 * (var_p + (mean gap)^2) / var_q, with variances as exp(2 * logs).
    scaled_spread = (
        0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
    )
    kl = (logs_q - logs_p) - 0.5
    kl += scaled_spread
    return kl
35
+
36
+
37
def rand_gumbel(shape):
    """Sample from the Gumbel distribution, protect from overflows."""
    # Squeeze the uniform samples into [1e-5, 0.99999) so neither log hits 0.
    uniform = torch.rand(shape) * 0.99998 + 0.00001
    inner = -torch.log(uniform)
    return -torch.log(inner)
41
+
42
+
43
def rand_gumbel_like(x):
    """Sample Gumbel noise with the size, dtype and device of ``x``."""
    noise = rand_gumbel(x.size())
    return noise.to(dtype=x.dtype, device=x.device)
46
+
47
+
48
def slice_segments(x, ids_str, segment_size=4):
    """Cut one window of ``segment_size`` along the last axis for each batch item.

    Args:
        x: 3-D tensor indexed as ``x[i, :, start:end]``.
        ids_str: per-item start index, indexed by batch position.
        segment_size: window length.

    Returns:
        Tensor shaped like ``x[:, :, :segment_size]`` holding the windows.
    """
    segments = torch.zeros_like(x[:, :, :segment_size])
    for batch_idx in range(x.size(0)):
        start = ids_str[batch_idx]
        segments[batch_idx] = x[batch_idx, :, start : start + segment_size]
    return segments
55
+
56
+
57
def rand_slice_segments(x, x_lengths=None, segment_size=4):
    """Cut a randomly positioned window from each batch item of ``x``.

    Start indices are drawn uniformly below ``x_lengths - segment_size + 1``;
    ``x_lengths`` defaults to the full length of the last axis.

    Returns:
        ``(segments, start_indices)``.
    """
    batch, _, total_length = x.size()
    lengths = total_length if x_lengths is None else x_lengths
    max_start = lengths - segment_size + 1
    draws = torch.rand([batch]).to(device=x.device)
    ids_str = (draws * max_start).to(dtype=torch.long)
    return slice_segments(x, ids_str, segment_size), ids_str
65
+
66
+
67
def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
    """Build a sinusoidal timing signal of shape ``[1, channels, length]``.

    Half of the channels carry ``sin`` and half carry ``cos`` of the position
    scaled by geometrically spaced timescales between ``min_timescale`` and
    ``max_timescale``. If ``channels`` is odd, one zero channel is appended.

    Args:
        length: number of positions.
        channels: number of output channels.
        min_timescale: smallest timescale.
        max_timescale: largest timescale.

    Returns:
        A float tensor of shape ``[1, channels, length]``.
    """
    position = torch.arange(length, dtype=torch.float)
    num_timescales = channels // 2
    # With a single timescale (channels of 2 or 3) ``num_timescales - 1`` is
    # zero; clamp the denominator so this no longer raises ZeroDivisionError.
    # For four or more channels the value is unchanged.
    log_timescale_increment = math.log(
        float(max_timescale) / float(min_timescale)
    ) / max(num_timescales - 1, 1)
    inv_timescales = min_timescale * torch.exp(
        torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
    )
    scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
    signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
    # Append a zero row when ``channels`` is odd so the view below fits.
    signal = F.pad(signal, [0, 0, 0, channels % 2])
    return signal.view(1, channels, length)
81
+
82
+
83
def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
    """Add the sinusoidal timing signal to a 3-D tensor ``x``.

    The signal is built for the sizes of axes 1 and 2 of ``x`` and moved to
    the dtype and device of ``x`` before the addition.
    """
    _, channels, length = x.size()
    timing = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
    return x + timing.to(dtype=x.dtype, device=x.device)
87
+
88
+
89
def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
    """Concatenate the sinusoidal timing signal to a 3-D tensor ``x`` along ``axis``.

    The signal is built for the sizes of axes 1 and 2 of ``x`` and moved to
    the dtype and device of ``x`` before the concatenation.
    """
    _, channels, length = x.size()
    timing = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
    timing = timing.to(dtype=x.dtype, device=x.device)
    return torch.cat([x, timing], axis)
93
+
94
+
95
def subsequent_mask(length):
    """Return a lower-triangular mask of ones with shape ``[1, 1, length, length]``."""
    lower_triangle = torch.ones(length, length).tril()
    return lower_triangle[None, None]
98
+
99
+
100
@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
    """Gated activation: ``tanh`` of the first channels times ``sigmoid`` of the rest.

    The two inputs are summed, the sum is split along axis 1 at
    ``n_channels[0]``, and the two parts are combined as
    ``tanh(first) * sigmoid(second)``.
    """
    split_at = n_channels[0]
    summed = input_a + input_b
    tanh_part = torch.tanh(summed[:, :split_at, :])
    gate_part = torch.sigmoid(summed[:, split_at:, :])
    return tanh_part * gate_part
108
+
109
+
110
def convert_pad_shape(pad_shape):
    """Flatten per-axis ``[before, after]`` pairs into a last-axis-first list.

    NOTE: this repeats a definition with the same name earlier in this
    module; being later, this is the one the name ends up bound to.
    """
    return [amount for pair in reversed(pad_shape) for amount in pair]
114
+
115
+
116
def shift_1d(x):
    """Shift a 3-D tensor one step right along the last axis, zero-filling the front.

    The last element of the axis is dropped, so the shape is unchanged.
    """
    left_pad = convert_pad_shape([[0, 0], [0, 0], [1, 0]])
    padded = F.pad(x, left_pad)
    return padded[:, :, :-1]
119
+
120
+
121
def sequence_mask(length, max_length=None):
    """Build a boolean mask that is true for positions below each length.

    Args:
        length: 1-D tensor of lengths.
        max_length: number of mask columns; defaults to ``length.max()``.

    Returns:
        Boolean tensor of shape ``[len(length), max_length]``.
    """
    columns = length.max() if max_length is None else max_length
    positions = torch.arange(columns, dtype=length.dtype, device=length.device)
    return positions.unsqueeze(0) < length.unsqueeze(1)
126
+
127
+
128
def generate_path(duration, mask):
    """Expand per-token durations into a monotonic alignment path.

    Args:
        duration: ``[b, 1, t_x]``.
        mask: ``[b, 1, t_y, t_x]``.

    Returns:
        Tensor shaped like ``mask``: the difference between each token's
        cumulative-duration mask and the previous token's, multiplied by
        ``mask``.
    """
    b, _, t_y, t_x = mask.shape

    cum_duration_flat = torch.cumsum(duration, -1).view(b * t_x)
    reached = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
    reached = reached.view(b, t_x, t_y)

    # Subtract the previous token's mask (shifted down one row along t_x)
    # so only the frames newly covered by each token remain.
    previous = F.pad(reached, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
    path = reached - previous
    return path.unsqueeze(1).transpose(2, 3) * mask
144
+
145
+
146
def clip_grad_value_(parameters, clip_value, norm_type=2):
    """Clamp gradients element-wise and report their total norm before clamping.

    Args:
        parameters: a tensor or an iterable of tensors; those whose ``grad``
            is ``None`` are ignored.
        clip_value: gradients are clamped in place to
            ``[-clip_value, clip_value]``; ``None`` disables clamping.
        norm_type: order of the norm to report. ``inf`` gives the largest
            per-parameter gradient norm.

    Returns:
        The total gradient norm as a Python float, measured before clamping.
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    with_grad = [p for p in parameters if p.grad is not None]
    norm_type = float(norm_type)
    if clip_value is not None:
        clip_value = float(clip_value)

    # The accumulate-then-root formula degenerates for the infinity norm
    # (it ends in ``x ** 0``), so that case is handled with a running max.
    use_max = norm_type == math.inf

    total_norm = 0.0
    for p in with_grad:
        param_norm = p.grad.data.norm(norm_type).item()
        if use_max:
            total_norm = max(total_norm, param_norm)
        else:
            total_norm += param_norm**norm_type
        if clip_value is not None:
            p.grad.data.clamp_(min=-clip_value, max=clip_value)
    if not use_max:
        total_norm = total_norm ** (1.0 / norm_type)
    return total_norm
162
+
163
+
164
def squeeze(x, x_mask=None, n_sqz=2):
    """Fold groups of ``n_sqz`` time steps into the channel axis.

    ``x`` of shape ``[b, c, t]`` becomes ``[b, c * n_sqz, t // n_sqz]``;
    trailing steps that do not fill a whole group are dropped.

    Args:
        x: tensor of shape ``[b, c, t]``.
        x_mask: optional mask; every ``n_sqz``-th entry starting at index
            ``n_sqz - 1`` is kept. Defaults to all ones.
        n_sqz: folding factor.

    Returns:
        ``(squeezed * mask, mask)``.
    """
    b, c, t = x.size()
    t_out = t // n_sqz

    trimmed = x[:, :, : t_out * n_sqz]
    grouped = trimmed.view(b, c, t_out, n_sqz)
    x_sqz = grouped.permute(0, 3, 1, 2).contiguous().view(b, c * n_sqz, t_out)

    if x_mask is None:
        x_mask = torch.ones(b, 1, t_out).to(device=x.device, dtype=x.dtype)
    else:
        x_mask = x_mask[:, :, n_sqz - 1 :: n_sqz]
    return x_sqz * x_mask, x_mask
177
+
178
+
179
def unsqueeze(x, x_mask=None, n_sqz=2):
    """Unfold channel groups back into the time axis.

    ``x`` of shape ``[b, c, t]`` becomes ``[b, c // n_sqz, t * n_sqz]``.

    Args:
        x: tensor of shape ``[b, c, t]``.
        x_mask: optional mask; each entry is repeated ``n_sqz`` times and
            reshaped to ``[b, 1, t * n_sqz]``. Defaults to all ones.
        n_sqz: unfolding factor.

    Returns:
        ``(unsqueezed * mask, mask)``.
    """
    b, c, t = x.size()
    c_out = c // n_sqz
    t_out = t * n_sqz

    grouped = x.view(b, n_sqz, c_out, t)
    x_unsqz = grouped.permute(0, 2, 3, 1).contiguous().view(b, c_out, t_out)

    if x_mask is None:
        x_mask = torch.ones(b, 1, t_out).to(device=x.device, dtype=x.dtype)
    else:
        x_mask = x_mask.unsqueeze(-1).repeat(1, 1, 1, n_sqz).view(b, 1, t_out)
    return x_unsqz * x_mask, x_mask