Spaces: Running on Zero
Upload folder using huggingface_hub
This view is limited to 50 files because it contains too many changes. See the raw diff for the complete change set.
- app.py +567 -0
- gitattributes +35 -0
- requirements.txt +14 -0
- soulxpodcast/__init__.py +0 -0
- soulxpodcast/__pycache__/__init__.cpython-311.pyc +0 -0
- soulxpodcast/__pycache__/__init__.cpython-312.pyc +0 -0
- soulxpodcast/__pycache__/config.cpython-311.pyc +0 -0
- soulxpodcast/__pycache__/config.cpython-312.pyc +0 -0
- soulxpodcast/cli/__pycache__/soulxpodcast.cpython-311.pyc +0 -0
- soulxpodcast/cli/__pycache__/soulxpodcast.cpython-312.pyc +0 -0
- soulxpodcast/cli/engine_test.py +74 -0
- soulxpodcast/cli/soulxpodcast.py +273 -0
- soulxpodcast/config.py +141 -0
- soulxpodcast/engine/__init__.py +0 -0
- soulxpodcast/engine/__pycache__/__init__.cpython-311.pyc +0 -0
- soulxpodcast/engine/__pycache__/__init__.cpython-312.pyc +0 -0
- soulxpodcast/engine/__pycache__/llm_engine.cpython-311.pyc +0 -0
- soulxpodcast/engine/__pycache__/llm_engine.cpython-312.pyc +0 -0
- soulxpodcast/engine/llm_engine.py +116 -0
- soulxpodcast/models/__pycache__/soulxpodcast.cpython-311.pyc +0 -0
- soulxpodcast/models/__pycache__/soulxpodcast.cpython-312.pyc +0 -0
- soulxpodcast/models/modules/__init__.py +0 -0
- soulxpodcast/models/modules/__pycache__/__init__.cpython-311.pyc +0 -0
- soulxpodcast/models/modules/__pycache__/__init__.cpython-312.pyc +0 -0
- soulxpodcast/models/modules/__pycache__/flow.cpython-311.pyc +0 -0
- soulxpodcast/models/modules/__pycache__/flow.cpython-312.pyc +0 -0
- soulxpodcast/models/modules/__pycache__/hifigan.cpython-311.pyc +0 -0
- soulxpodcast/models/modules/__pycache__/hifigan.cpython-312.pyc +0 -0
- soulxpodcast/models/modules/__pycache__/sampler.cpython-311.pyc +0 -0
- soulxpodcast/models/modules/__pycache__/sampler.cpython-312.pyc +0 -0
- soulxpodcast/models/modules/flow.py +197 -0
- soulxpodcast/models/modules/flow_components/__init__.py +0 -0
- soulxpodcast/models/modules/flow_components/__pycache__/__init__.cpython-311.pyc +0 -0
- soulxpodcast/models/modules/flow_components/__pycache__/__init__.cpython-312.pyc +0 -0
- soulxpodcast/models/modules/flow_components/__pycache__/estimator.cpython-311.pyc +0 -0
- soulxpodcast/models/modules/flow_components/__pycache__/estimator.cpython-312.pyc +0 -0
- soulxpodcast/models/modules/flow_components/__pycache__/upsample_encoder.cpython-311.pyc +0 -0
- soulxpodcast/models/modules/flow_components/__pycache__/upsample_encoder.cpython-312.pyc +0 -0
- soulxpodcast/models/modules/flow_components/estimator.py +974 -0
- soulxpodcast/models/modules/flow_components/upsample_encoder.py +998 -0
- soulxpodcast/models/modules/hifigan.py +249 -0
- soulxpodcast/models/modules/hifigan_components/__init__.py +0 -0
- soulxpodcast/models/modules/hifigan_components/__pycache__/__init__.cpython-311.pyc +0 -0
- soulxpodcast/models/modules/hifigan_components/__pycache__/__init__.cpython-312.pyc +0 -0
- soulxpodcast/models/modules/hifigan_components/__pycache__/layers.cpython-311.pyc +0 -0
- soulxpodcast/models/modules/hifigan_components/__pycache__/layers.cpython-312.pyc +0 -0
- soulxpodcast/models/modules/hifigan_components/layers.py +433 -0
- soulxpodcast/models/modules/sampler.py +221 -0
- soulxpodcast/models/soulxpodcast.py +192 -0
- soulxpodcast/utils/__init__.py +0 -0
app.py ADDED
@@ -0,0 +1,567 @@
import re
import gradio as gr
from tqdm import tqdm
from argparse import ArgumentParser
from typing import Literal, List, Tuple
import sys
import importlib.util
from datetime import datetime
import spaces
import torch
import numpy as np  # make sure numpy is imported
import random  # make sure random is imported
import s3tokenizer

from soulxpodcast.models.soulxpodcast import SoulXPodcast
from soulxpodcast.config import Config, SoulXPodcastLLMConfig, SamplingParams
from soulxpodcast.utils.dataloader import (
    PodcastInferHandler,
    SPK_DICT, TEXT_START, TEXT_END, AUDIO_START, TASK_PODCAST
)

# ================================================
# Example audio paths
# ================================================
S1_PROMPT_WAV = "assets/audios/female_mandarin.wav"  # example path
S2_PROMPT_WAV = "assets/audios/male_mandarin.wav"  # example path


# ================================================
# Example data (gr.Examples)
# ================================================
EXAMPLES_LIST = [
    # Example 1: clear all fields
    [
        None, "", "", None, "", "", ""
    ],
    # Example 2: standard podcast
    [
        S1_PROMPT_WAV,
        "喜欢攀岩、徒步、滑雪的语言爱好者,以及过两天要带着全部家当去景德镇做陶瓷的白日梦想家。",
        "",
        S2_PROMPT_WAV,
        "呃,还有一个就是要跟大家纠正一点,就是我们在看电影的时候,尤其是游戏玩家,看电影的时候,在看到那个到西北那边的这个陕北民谣,嗯,这个可能在想,哎,是不是他是受到了黑神话的启发?",
        "",
"[S1] 哈喽,AI时代的冲浪先锋们!欢迎收听《AI生活进行时》。啊,一个充满了未来感,然后,还有一点点,<|laughter|>神经质的播客节目,我是主持人小希。\n[S2] 哎,大家好呀!我是能唠,爱唠,天天都想唠的唠嗑!\n[S1] 最近活得特别赛博朋克哈!以前老是觉得AI是科幻片儿里的,<|sigh|> 现在,现在连我妈都用AI写广场舞文案了。\n[S2] 这个例子很生动啊。是的,特别是生成式AI哈,感觉都要炸了! 诶,那我们今天就聊聊AI是怎么走进我们的生活的哈!",
    ],
    # Example 3: Sichuan-dialect podcast
    [
        S1_PROMPT_WAV,
        "喜欢攀岩、徒步、滑雪的语言爱好者,以及过两天要带着全部家当去景德镇做陶瓷的白日梦想家。",
        "<|Sichuan|>要得要得!前头几个耍洋盘,我后脚就背起铺盖卷去景德镇耍泥巴,巴适得喊老天爷!",
        S2_PROMPT_WAV,
        "呃,还有一个就是要跟大家纠正一点,就是我们在看电影的时候,尤其是游戏玩家,看电影的时候,在看到那个到西北那边的这个陕北民谣,嗯,这个可能在想,哎,是不是他是受到了黑神话的启发?",
        "<|Sichuan|>哎哟喂,这个搞反了噻!黑神话里头唱曲子的王二浪早八百年就在黄土高坡吼秦腔喽,游戏组专门跑切录的原汤原水,听得人汗毛儿都立起来!",
"[S1] <|Sichuan|>各位《巴适得板》的听众些,大家好噻!我是你们主持人晶晶。今儿天气硬是巴适,不晓得大家是在赶路嘛,还是茶都泡起咯,准备跟我们好生摆一哈龙门阵喃?\n[S2] <|Sichuan|>晶晶好哦,大家安逸噻!我是李老倌。你刚开口就川味十足,摆龙门阵几个字一甩出来,我鼻子头都闻到茶香跟火锅香咯!\n[S1] <|Sichuan|>就是得嘛!李老倌,我前些天带个外地朋友切人民公园鹤鸣茶社坐了一哈。他硬是搞不醒豁,为啥子我们一堆人围到杯茶就可以吹一下午壳子,从隔壁子王嬢嬢娃儿耍朋友,扯到美国大选,中间还掺几盘斗地主。他说我们四川人简直是把摸鱼刻进骨子里头咯!\n[S2] <|Sichuan|>你那个朋友说得倒是有点儿趣,但他莫看到精髓噻。摆龙门阵哪是摸鱼嘛,这是我们川渝人特有的交际方式,更是一种活法。外省人天天说的松弛感,根根儿就在这龙门阵里头。今天我们就要好生摆一哈,为啥子四川人活得这么舒坦。就先从茶馆这个老窝子说起,看它咋个成了我们四川人的魂儿!",
    ],
    # Example 4: Cantonese podcast
    [
        S1_PROMPT_WAV,
        "喜欢攀岩、徒步、滑雪的语言爱好者,以及过两天要带着全部家当去景德镇做陶瓷的白日梦想家。",
        "<|Yue|>真係冇讲错啊!攀山滑雪嘅语言专家几巴闭,都唔及我听日拖成副身家去景德镇玩泥巴,呢铺真系发哂白日梦咯!",
        S2_PROMPT_WAV,
        "呃,还有一个就是要跟大家纠正一点,就是我们在看电影的时候,尤其是游戏玩家,看电影的时候,在看到那个到西北那边的这个陕北民谣,嗯,这个可能在想,哎,是不是他是受到了黑神话的启发?",
        "<|Yue|>咪搞错啊!陕北民谣响度唱咗几十年,黑神话边有咁大面啊?你估佢哋抄游戏咩!",
"[S1] <|Yue|>哈囉大家好啊,歡迎收聽我哋嘅節目。喂,我今日想問你樣嘢啊,你覺唔覺得,嗯,而家揸電動車,最煩,最煩嘅一樣嘢係咩啊?\n[S2] <|Yue|>梗係充電啦。大佬啊,搵個位都已經好煩,搵到個位仲要喺度等,你話快極都要半個鐘一個鐘,真係,有時諗起都覺得好冇癮。\n[S1] <|Yue|>係咪先。如果我而家同你講,充電可以快到同入油差唔多時間,你信唔信先?喂你平時喺油站入滿一缸油,要幾耐啊?五六分鐘?\n[S2] <|Yue|>差唔多啦,七八分鐘,點都走得啦。電車喎,可以做到咁快?你咪玩啦。",
    ],
    # Example 5: Henan-dialect podcast
    [
        S1_PROMPT_WAV,
        "喜欢攀岩、徒步、滑雪的语言爱好者,以及过两天要带着全部家当去景德镇做陶瓷的白日梦想家。",
        "<|Henan|>俺这不是怕恁路上不得劲儿嘛!那景德镇瓷泥可娇贵着哩,得先拿咱河南人这实诚劲儿给它揉透喽。",
        S2_PROMPT_WAV,
        "呃,还有一个就是要跟大家纠正一点,就是我们在看电影的时候,尤其是游戏玩家,看电影的时候,在看到那个到西北那边的这个陕北民谣,嗯,这个可能在想,哎,是不是他是受到了黑神话的启发?",
        "<|Henan|>恁这想法真闹挺!陕北民谣比黑神话早几百年都有了,咱可不兴这弄颠倒啊,中不?恁这想法真闹挺!那陕北民谣在黄土高坡响了几百年,咋能说是跟黑神话学的咧?咱得把这事儿捋直喽,中不中!",
"[S1] <|Henan|>哎,大家好啊,欢迎收听咱这一期嘞《瞎聊呗,就这么说》,我是恁嘞老朋友,燕子。\n[S2] <|Henan|>大家好,我是老张。燕子啊,今儿瞅瞅你这个劲儿,咋着,是有啥可得劲嘞事儿想跟咱唠唠?\n[S1] <|Henan|>哎哟,老张,你咋恁懂我嘞!我跟你说啊,最近我刷手机,老是刷住些可逗嘞方言视频,特别是咱河南话,咦~我哩个乖乖,一听我都憋不住笑,咋说嘞,得劲儿哩很,跟回到家一样。\n[S2] <|Henan|>你这回可算说到根儿上了!河南话,咱往大处说说,中原官话,它真嘞是有一股劲儿搁里头。它可不光是说话,它脊梁骨后头藏嘞,是咱一整套、鲜鲜活活嘞过法儿,一种活人嘞道理。\n[S1] <|Henan|>活人嘞道理?哎,这你这一说,我嘞兴致“腾”一下就上来啦!觉住咱这嗑儿,一下儿从搞笑视频蹿到文化顶上了。那你赶紧给我白话白话,这里头到底有啥道道儿?我特别想知道——为啥一提起咱河南人,好些人脑子里“蹦”出来嘞头一个词儿,就是实在?这个实在,骨子里到底是啥嘞?",
    ],
]


# ================================================
# SoulX-Podcast Model
# ================================================
model: SoulXPodcast = None
dataset: PodcastInferHandler = None
def initiate_model(config: Config, enable_tn: bool = False):
    global model
    if model is None:
        model = SoulXPodcast(config)

    global dataset
    if dataset is None:
        dataset = PodcastInferHandler(model.llm.tokenizer, None, config)

# ================================================
# Gradio
# ================================================

_i18n_key2lang_dict = dict(
    # Speaker 1 prompt
    spk1_prompt_audio_label=dict(
        en="Speaker 1 Prompt Audio",
        zh="说话人 1 参考语音",
    ),
    spk1_prompt_text_label=dict(
        en="Speaker 1 Prompt Text",
        zh="说话人 1 参考文本",
    ),
    spk1_prompt_text_placeholder=dict(
        en="text of speaker 1 Prompt audio.",
        zh="说话人 1 参考文本",
    ),
    spk1_prompt_cot_text_label=dict(
        en="Speaker 1 Prompt COT Text",
        zh="说话人 1 参考推理链文本",
    ),
    spk1_prompt_cot_text_placeholder=dict(
        en="Dialect prompt cot text with prefix: <|Sichuan|>/<|Yue|>/<|Henan|>",
        zh="带前缀方言提示词思维链文本,前缀如下:<|Sichuan|>/<|Yue|>/<|Henan|>,如:<|Sichuan|>走嘛,切吃那家新开的麻辣烫,听别个说味道硬是霸道得很,好吃到不摆了,去晚了还得排队!",
    ),
    # Speaker 2 prompt
    spk2_prompt_audio_label=dict(
        en="Speaker 2 Prompt Audio",
        zh="说话人 2 参考语音",
    ),
    spk2_prompt_text_label=dict(
        en="Speaker 2 Prompt Text",
        zh="说话人 2 参考文本",
    ),
    spk2_prompt_text_placeholder=dict(
        en="[S2] text of speaker 2 prompt audio.",
        zh="[S2] 说话人 2 参考文本",
    ),
    spk2_prompt_cot_text_label=dict(
        en="Speaker 2 Prompt COT Text",
        zh="说话人 2 参考推理链文本",
    ),
    spk2_prompt_cot_text_placeholder=dict(
        en="Dialect prompt cot text with prefix: <|Sichuan|>/<|Yue|>/<|Henan|>",
        zh="带前缀方言提示词思维链文本,前缀如下:<|Sichuan|>/<|Yue|>/<|Henan|>,如:<|Sichuan|>走嘛,切吃那家新开的麻辣烫,听别个说味道硬是霸道得很,好吃到不摆了,去晚了还得排队!",
    ),
    # Dialogue input textbox
    dialogue_text_input_label=dict(
        en="Dialogue Text Input",
        zh="合成文本输入",
    ),
    dialogue_text_input_placeholder=dict(
        en="[S1]text[S2]text[S1]text...",
        zh="[S1]文本[S2]文本[S1]文本...",
    ),
    # Generate button
    generate_btn_label=dict(
        en="Generate Audio",
        zh="合成",
    ),
    # Generated audio
    generated_audio_label=dict(
        en="Generated Dialogue Audio",
        zh="合成的对话音频",
    ),
    # Warning 1: invalid prompt text
    warn_invalid_spk1_prompt_text=dict(
        en='Invalid speaker 1 prompt text, should not be empty and strictly follow: "xxx"',
        zh='说话人 1 参考文本不合规,不能为空,格式:"xxx"',
    ),
    # warn_invalid_spk1_prompt_cot_text=dict(
    #     en='Invalid speaker 1 prompt cot text, should not be empty and strictly follow: "[S1]xxx"',
    #     zh='说话人 1 参考文本不合规,格式:"[S1]xxx"',
    # ),
    warn_invalid_spk2_prompt_text=dict(
        en='Invalid speaker 2 prompt text, should strictly follow: "[S2]xxx"',
        zh='说话人 2 参考文本不合规,格式:"[S2]xxx"',
    ),
    # Warning 2: invalid dialogue input text
    warn_invalid_dialogue_text=dict(
        en='Invalid dialogue input text, should strictly follow: "[S1]xxx[S2]xxx..."',
        zh='对话文本输入不合规,格式:"[S1]xxx[S2]xxx..."',
    ),
    # Warning 3: incomplete prompt info
    warn_incomplete_prompt=dict(
        en="Please provide prompt audio and text for both speaker 1 and speaker 2",
        zh="请提供说话人 1 与说话人 2 的参考语音与参考文本",
    ),
)


global_lang: Literal["zh", "en"] = "zh"

def i18n(key):
    # (unchanged)
    global global_lang
    return _i18n_key2lang_dict[key][global_lang]

def check_monologue_text(text: str, prefix: str = None) -> bool:
    text = text.strip()
    # Check speaker tags
    if prefix is not None and (not text.startswith(prefix)):
        return False
    # Remove prefix
    if prefix is not None:
        text = text.removeprefix(prefix)
    text = text.strip()
    # Empty after stripping?
    if len(text) == 0:
        return False
    return True

def check_dialect_prompt_cot_text(text: str, prefix: str = None) -> bool:
    text = text.strip()
    # Check COT prefix tags
    if prefix is not None and (not text.startswith(prefix)):
        return False
    text = text.strip()
    # Empty after stripping?
    if len(text) == 0:
        return False
    return True

def check_dialogue_text(text_list: List[str]) -> bool:
    if len(text_list) == 0:
        return False
    for text in text_list:
        if not (
            check_monologue_text(text, "[S1]")
            or check_monologue_text(text, "[S2]")
            or check_monologue_text(text, "[S3]")
            or check_monologue_text(text, "[S4]")
        ):
            return False
    return True

def process_single(target_text_list, prompt_wav_list, prompt_text_list, use_prompt_cot, prompt_cot_text_list):
    spks, texts = [], []
    for target_text in target_text_list:
        pattern = r'(\[S[1-9]\])(.+)'
        match = re.match(pattern, target_text)
        text, spk = match.group(2), int(match.group(1)[2]) - 1
        spks.append(spk)
        texts.append(text)

    global dataset
    dataitem = {"key": "001", "prompt_text": prompt_text_list, "prompt_wav": prompt_wav_list,
                "text": texts, "spk": spks, }
    if use_prompt_cot:
        dataitem.update({
            "prompt_cot_text": prompt_cot_text_list
        })
    dataset.update_datasource(
        [
            dataitem
        ]
    )

    # assumes a single data item only
    data = dataset[0]
    prompt_mels_for_llm, prompt_mels_lens_for_llm = s3tokenizer.padding(data["log_mel"])  # [B, num_mels=128, T]
    spk_emb_for_flow = torch.tensor(data["spk_emb"])
    prompt_mels_for_flow = torch.nn.utils.rnn.pad_sequence(data["mel"], batch_first=True, padding_value=0)  # [B, T', num_mels=80]
    prompt_mels_lens_for_flow = torch.tensor(data['mel_len'])
    text_tokens_for_llm = data["text_tokens"]
    prompt_text_tokens_for_llm = data["prompt_text_tokens"]
    spk_ids = data["spks_list"]
    sampling_params = SamplingParams(use_ras=True, win_size=25, tau_r=0.2)
    infos = [data["info"]]
    processed_data = {
        "prompt_mels_for_llm": prompt_mels_for_llm,
        "prompt_mels_lens_for_llm": prompt_mels_lens_for_llm,
        "prompt_text_tokens_for_llm": prompt_text_tokens_for_llm,
        "text_tokens_for_llm": text_tokens_for_llm,
        "prompt_mels_for_flow_ori": prompt_mels_for_flow,
        "prompt_mels_lens_for_flow": prompt_mels_lens_for_flow,
        "spk_emb_for_flow": spk_emb_for_flow,
        "sampling_params": sampling_params,
        "spk_ids": spk_ids,
        "infos": infos,
        "use_prompt_cot": use_prompt_cot,
    }
    if use_prompt_cot:
        processed_data.update({
            "prompt_cot_text_tokens_for_llm": data["prompt_cot_text_tokens"],
            "prompt_cot_prefix": data["prompt_cot_prefix"],
        })
    return processed_data

@spaces.GPU
def dialogue_synthesis_function(
    target_text: str,
    spk1_prompt_text: str | None = "",
    spk1_prompt_audio: str | None = None,
    spk1_prompt_cot_text: str | None = "",
    spk2_prompt_text: str | None = "",
    spk2_prompt_audio: str | None = None,
    spk2_prompt_cot_text: str | None = "",
    seed: int = 1988,  # the seed parameter is kept
):
    # ================== Set random seeds ==================
    seed = int(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    # ================================================

    # Check prompt info
    target_text_list: List[str] = re.findall(r"(\[S[0-9]\][^\[\]]*)", target_text)
    target_text_list = [text.strip() for text in target_text_list]
    if not check_dialogue_text(target_text_list):
        gr.Warning(message=i18n("warn_invalid_dialogue_text"))
        return None

    # Go synthesis
    progress_bar = gr.Progress(track_tqdm=True)
    prompt_wav_list = [spk1_prompt_audio, spk2_prompt_audio]
    prompt_text_list = [spk1_prompt_text, spk2_prompt_text]
    use_prompt_cot = spk1_prompt_cot_text.strip() != "" or spk2_prompt_cot_text.strip() != ""
    prompt_cot_text_list = [spk1_prompt_cot_text, spk2_prompt_cot_text]
    data = process_single(
        target_text_list,
        prompt_wav_list,
        prompt_text_list,
        use_prompt_cot,
        prompt_cot_text_list,
    )

    results_dict = model.forward_longform(
        **data
    )
    target_audio = None
    for i in range(len(results_dict['generated_wavs'])):
        if target_audio is None:
            target_audio = results_dict['generated_wavs'][i]
        else:
            target_audio = torch.concat([target_audio, results_dict['generated_wavs'][i]], axis=1)
    return (24000, target_audio.cpu().squeeze(0).numpy())

def render_interface() -> gr.Blocks:
    with gr.Blocks(title="SoulX-Podcast", theme=gr.themes.Default()) as page:
        # ======================== UI ========================
        with gr.Row():
            lang_choice = gr.Radio(
                choices=["中文", "English"],
                value="中文",
                label="Display Language/显示语言",
                type="index",
                interactive=True,
                scale=3,
            )
            seed_input = gr.Number(
                label="Seed (种子)",
                value=1988,
                step=1,
                interactive=True,
                scale=1,
            )

        with gr.Row():
            # ==== Speaker 1 prompt ====
            with gr.Column(scale=1):
                with gr.Group(visible=True) as spk1_prompt_group:
                    spk1_prompt_audio = gr.Audio(
                        label=i18n("spk1_prompt_audio_label"),
                        type="filepath",
                        editable=False,
                        interactive=True,
                    )
                    spk1_prompt_text = gr.Textbox(
                        label=i18n("spk1_prompt_text_label"),
                        placeholder=i18n("spk1_prompt_text_placeholder"),
                        lines=3,
                    )
                    spk1_prompt_cot_text = gr.Textbox(
                        label=i18n("spk1_prompt_cot_text_label"),
                        placeholder=i18n("spk1_prompt_cot_text_placeholder"),
                        value="",
                        lines=3,
                    )
            # ==== Speaker 2 prompt ====
            with gr.Column(scale=1, visible=True):
                with gr.Group(visible=True) as spk2_prompt_group:
                    spk2_prompt_audio = gr.Audio(
                        label=i18n("spk2_prompt_audio_label"),
                        type="filepath",
                        editable=False,
                        interactive=True,
                    )
                    spk2_prompt_text = gr.Textbox(
                        label=i18n("spk2_prompt_text_label"),
                        placeholder=i18n("spk2_prompt_text_placeholder"),
                        lines=3,
                    )
                    spk2_prompt_cot_text = gr.Textbox(
                        label=i18n("spk2_prompt_cot_text_label"),
                        placeholder=i18n("spk2_prompt_cot_text_placeholder"),
                        value="",
                        lines=3,
                    )
            # ==== Text input ====
            with gr.Column(scale=2):
                with gr.Row():
                    dialogue_text_input = gr.Textbox(
                        label=i18n("dialogue_text_input_label"),
                        placeholder=i18n("dialogue_text_input_placeholder"),
                        lines=18,
                    )

        # Generate button
        with gr.Row():
            generate_btn = gr.Button(
                value=i18n("generate_btn_label"),
                variant="primary",
                scale=3,
                size="lg",
            )

        # Long output audio
        generate_audio = gr.Audio(
            label=i18n("generated_audio_label"),
            interactive=False,
        )

        with gr.Row():
            inputs_for_examples = [
                spk1_prompt_audio,
                spk1_prompt_text,
                spk1_prompt_cot_text,
                spk2_prompt_audio,
                spk2_prompt_text,
                spk2_prompt_cot_text,
                dialogue_text_input,
            ]

            example_headers = [
                "S1 音频", "S1 文本", "S1 COT",
                "S2 音频", "S2 文本", "S2 COT",
                "对话内容"
            ]

            gr.Examples(
                examples=EXAMPLES_LIST,
                inputs=inputs_for_examples,
                label="播客模板示例 (点击加载)",
                examples_per_page=5,
            )

        # ======================== Action ========================
        def _change_component_language(lang):
            global global_lang
            global_lang = ["zh", "en"][lang]
            return [
                # spk1_prompt_{audio,text,prompt_cot_text}
                gr.update(label=i18n("spk1_prompt_audio_label")),
                gr.update(
                    label=i18n("spk1_prompt_text_label"),
                    placeholder=i18n("spk1_prompt_text_placeholder"),
                ),
                gr.update(
                    label=i18n("spk1_prompt_cot_text_label"),
                    placeholder=i18n("spk1_prompt_cot_text_placeholder"),
                ),
                # spk2_prompt_{audio,text}
                gr.update(label=i18n("spk2_prompt_audio_label")),
                gr.update(
                    label=i18n("spk2_prompt_text_label"),
                    placeholder=i18n("spk2_prompt_text_placeholder"),
                ),
                gr.update(
                    label=i18n("spk2_prompt_cot_text_label"),
                    placeholder=i18n("spk2_prompt_cot_text_placeholder"),
                ),
                # dialogue_text_input
                gr.update(
                    label=i18n("dialogue_text_input_label"),
                    placeholder=i18n("dialogue_text_input_placeholder"),
                ),
                # generate_btn
                gr.update(value=i18n("generate_btn_label")),
                # generate_audio
                gr.update(label=i18n("generated_audio_label")),
            ]

        lang_choice.change(
            fn=_change_component_language,
            inputs=[lang_choice],
            outputs=[
                spk1_prompt_audio,
                spk1_prompt_text,
                spk1_prompt_cot_text,
                spk2_prompt_audio,
                spk2_prompt_text,
                spk2_prompt_cot_text,
                dialogue_text_input,
                generate_btn,
                generate_audio,
            ],
        )

        # Generate button click action
        generate_btn.click(
            fn=dialogue_synthesis_function,
            inputs=[
                dialogue_text_input,
                spk1_prompt_text,
                spk1_prompt_audio,
                spk1_prompt_cot_text,
                spk2_prompt_text,
                spk2_prompt_audio,
                spk2_prompt_cot_text,
                seed_input,
            ],
            outputs=[generate_audio],
        )
    return page


# ================================================
# Options
# ================================================
def get_args():
    parser = ArgumentParser()
    parser.add_argument('--model_path',
                        required=True,
                        type=str,
                        help='model path')
    parser.add_argument('--llm_engine',
                        type=str,
                        default="hf",
                        help='model execute engine')
    parser.add_argument('--fp16_flow',
                        action='store_true',
                        help='enable fp16 flow')
    parser.add_argument('--seed',
                        type=int,
                        default=1988,
                        help='random seed for generation')
    args = parser.parse_args()
    return args


if __name__ == "__main__":
    args = get_args()

    # Initiate model
    hf_config = SoulXPodcastLLMConfig.from_initial_and_json(
        initial_values={"fp16_flow": args.fp16_flow},
        json_file=f"{args.model_path}/soulxpodcast_config.json")

    # Compatible with the absence of a vLLM installation
    llm_engine = args.llm_engine
    if llm_engine == "vllm":
        if not importlib.util.find_spec("vllm"):
            llm_engine = "hf"
            timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
            tqdm.write(f"[{timestamp}] - [WARNING]: vLLM is not installed, switching to the hf engine.")
    config = Config(model=args.model_path, enforce_eager=True, llm_engine=llm_engine,
                    hf_config=hf_config)

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    initiate_model(config)
    print("[INFO] SoulX-Podcast loaded")
    # UI
    page = render_interface()
    page.queue()
    page.launch()
    # page.launch(share=False)
gitattributes ADDED
@@ -0,0 +1,35 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
requirements.txt ADDED
@@ -0,0 +1,14 @@
librosa
numpy
scipy
s3tokenizer
diffusers
torch==2.7.1
torchaudio==2.7.1
triton>=3.0.0
transformers==4.57.1
accelerate==1.10.1
onnxruntime
onnxruntime-gpu
einops
gradio
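
Note that vllm is not pinned here even though app.py accepts --llm_engine vllm: the main block checks for the package at startup and falls back to the Hugging Face engine when it is missing. A minimal sketch of that availability check, mirroring the logic in app.py:

import importlib.util

llm_engine = "vllm"
if llm_engine == "vllm" and not importlib.util.find_spec("vllm"):
    # vLLM is not installed; fall back to the plain Hugging Face engine,
    # exactly as app.py does before building its Config.
    llm_engine = "hf"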
soulxpodcast/__init__.py ADDED
File without changes

soulxpodcast/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (177 Bytes).

soulxpodcast/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (165 Bytes).

soulxpodcast/__pycache__/config.cpython-311.pyc ADDED
Binary file (9.22 kB).

soulxpodcast/__pycache__/config.cpython-312.pyc ADDED
Binary file (7.47 kB).

soulxpodcast/cli/__pycache__/soulxpodcast.cpython-311.pyc ADDED
Binary file (15.2 kB).

soulxpodcast/cli/__pycache__/soulxpodcast.cpython-312.pyc ADDED
Binary file (13.6 kB).
soulxpodcast/cli/engine_test.py ADDED
@@ -0,0 +1,74 @@
import argparse
import json
import os
import random
import sys
from glob import glob
from copy import deepcopy
import numpy as np
import torch
from tqdm import tqdm
from dataclasses import fields, asdict

from vllm import LLM
from vllm.inputs import TokensPrompt as TokensPrompt
from vllm import SamplingParams

def set_all_random_seed(seed):
    import random
    import numpy as np
    import os
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def get_args():
    parser = argparse.ArgumentParser(description='FlashCosyVoice')
    parser.add_argument('--model_path',
                        required=True,
                        type=str,
                        help='model path')
    parser.add_argument('--seed',
                        type=int,
                        default=1986,
                        help='random seed for generation')
    args = parser.parse_args()
    return args

def main():
    args = get_args()

    os.environ["VLLM_USE_V1"] = "0"
    sampling_params = SamplingParams(temperature=0.6, repetition_penalty=1.25, top_k=100, top_p=0.9, max_tokens=3000, stop_token_ids=[153478], use_ras=True, win_size=25, tau_r=0.2)

    llm = LLM(model=args.model_path, enforce_eager=True, dtype="bfloat16", )
input = [152477, 151674, 151712, 152596, 120432, 102115, 101417, 121407, 3837, 113562, 108896, 103554, 34187, 99744, 104283, 3837, 104074, 18830, 101417, 63109, 29826, 103924, 11319, 151713, 153477, 157342, 157987, 158071, 157264, 153912, 154998, 159371, 159391, 159562, 158917, 158870, 159680, 159672, 157485, 155058, 155207, 153630, 153846, 158058, 153814, 153841, 158204, 154543, 158928, 154055, 155761, 155026, 155298, 155061, 155884, 159856, 155815, 158542, 156292, 154332, 155055, 153606, 159474, 157971, 158055, 158607, 158193, 155035, 155038, 159420, 159643, 157699, 155778, 154224, 158652, 158461, 158551, 156465, 157728, 155541, 155307, 155297, 157894, 157894, 157813, 157570, 159511, 159616, 158914, 158833, 158935, 155045, 155770, 157962, 159491, 157384, 157398, 157077, 158859, 158693, 159419, 159565, 158796, 157484, 157248, 158029, 158348, 158271, 156337, 154890, 154638, 155268, 159637, 158827, 159646, 158179, 158159, 155244, 155280, 155283, 154488, 156999, 156279, 157007, 157891, 160081, 160000, 153478, 151675, 151712, 152596, 18830, 101417, 106909, 56568, 104138, 36587, 3837, 100132, 107133, 117320, 3837, 105146, 53153, 102155, 117927, 1773, 151713, 153477, 155303, 157049, 159236, 159344, 158936, 158693, 159425, 159424, 158609, 158266, 159972, 159969, 160052, 160055, 159971, 158375, 159670, 159427, 159428, 155783, 159321, 160053, 158284, 158132, 156714, 157470, 155283, 155286, 155167, 157324, 159511, 154330, 155635, 159613, 158481, 156348, 154160, 155572, 158516, 158326, 154645, 155619, 155457, 157343, 158882, 158155, 154777, 154249, 158695, 158675, 155055, 154332, 155079, 159004, 158510, 158591, 158593, 157151, 158697, 158688, 158595, 158577, 158576, 157814, 157894, 157894, 157813, 157813, 157813, 157813, 157732, 157240, 157483, 155296, 155626, 155761, 153659, 154194, 155052, 157968, 159421, 158032, 159320, 159247, 159015, 157789, 159974, 159977, 159986, 154641, 154392, 156328, 156094, 154413, 157728, 159509, 153733, 155680, 155860, 154889, 155618, 155609, 156581, 158191, 156949, 157482, 158691, 158694, 160100, 159125, 159772, 155627, 155626, 157489, 159514, 158785, 156598, 154411, 155626, 160054, 160130, 160051, 160066, 159337, 158365, 158194, 154783, 154867, 158372, 158615, 158616, 155943, 157263, 157587, 155425, 155380, 157321, 155080, 155734, 155029, 158690, 154426, 157259, 155847, 158961, 159988, 157069, 156513, 155541, 154569, 155316, 155403, 157427, 158238, 158588, 158668, 155935, 159970, 160050, 159725, 159778, 159967, 155369, 155370, 156828, 158446, 155005, 154246, 156433, 158539, 156455, 154308, 153738, 153990, 153909, 154152, 156429, 158426, 158452, 158659, 157085, 157813, 157813, 157813, 153478, 151674, 151712, 152596, 50930, 100014, 36587, 112491, 54039, 100391, 107512, 113524, 9370, 115639, 100930, 34718, 104038, 26940, 110602, 17447, 99469, 28029, 25067, 69249, 109624, 116542, 3837, 108243, 82166, 99172, 58364, 52129, 104207, 9370, 22697, 104097, 99882, 99615, 70074, 3837, 120432, 18830, 101417, 110721, 11319, 151713, 153477, 154327, 153757, 159436, 157245, 153603, 156051, 158027, 158273, 154133, 153918, 157908, 159372, 158681, 158163, 157428, 159572, 159886, 155049, 154305, 158538, 153973, 153595, 153676, 159430, 155060, 154575, 156006, 160129, 160138, 158666, 160017, 155156, 155437, 157459, 154713, 154962, 157239, 156016, 158272, 156688, 158384, 158541, 154218, 156186, 159262, 158614, 158535, 155203, 159574, 157316, 159669, 159657, 155284, 157300, 157456, 157159, 154906, 155781, 154720, 153652, 157183, 159422, 155206, 158190, 158082, 157996, 159239, 158513, 156169, 
154314, 155727, 159124, 155689, 158078, 157247, 155304, 153900, 154390, 159975, 159726, 159968, 158257, 158203, 155512, 158056, 158479, 158085, 156519, 154329, 154364, 155059, 159433, 159433, 158704, 155140, 155707, 153649, 156516, 155057, 154656, 154079, 155191, 155775, 157239, 156669, 154951, 158517, 158030, 158245, 158299, 154123, 154728, 154857, 157830, 159853, 158885, 158425, 158072, 157749, 155157, 159535, 159619, 158836, 156597, 155540, 155550, 154101, 156277, 156349, 156124, 158349, 153651, 153898, 156835, 156445, 159397, 157232, 155128, 157158, 156023, 159500, 155871, 154557, 157319, 159577, 158938, 158158, 158629, 159287, 157833, 157860, 157887, 159426, 156355, 158623, 154287, 155544, 157727, 159512, 158038, 158514, 158533, 155790, 154812, 154830, 155542, 155056, 157246, 157243, 155053, 155620, 154813, 154000, 158455, 158621, 158678, 155051, 154177, 160154, 158689, 153976, 158366, 156021, 154233, 154161, 158002, 158270, 153890, 158781, 158639, 158642, 160132, 160133, 160144, 157967, 155694, 157644, 156024, 158135, 155784, 155078, 155491, 155599, 154326, 155543, 154821, 156774, 159533, 158480, 158417, 158054, 157258, 155787, 153792, 159538, 158890, 158809, 158809, 158107, 156577, 155133, 154404, 155446, 158515, 157076, 158453, 159040, 158392, 157669, 159245, 158524, 153708, 155546, 154899, 158660, 160096, 158627, 159431, 157490, 154197, 156096, 158887, 156733, 154434, 154119, 160013, 159043, 155573, 155545, 157327, 159433, 156517, 155383, 155401, 155072, 153639, 156294, 156888, 158664, 158663, 159562, 155920, 157242, 157726, 157241, 154517, 155356, 157708, 160151, 160153, 159418, 159101, 159263, 158534, 158364, 158004, 156524, 157567, 157894, 157894, 153478, 151675, 151712, 152596, 110931, 45629, 101454, 17447, 40814, 99164, 100753, 41683, 9370, 2073, 119176, 103162, 102781, 854, 100074, 99615, 47872, 6313, 115639, 112491, 99615, 81668, 104462, 99319, 100307, 102561, 99964, 101182, 99319, 108375, 100074, 3837, 99557, 99601, 104631, 108439, 100372, 100369, 99375, 99261, 6313, 151713, 153477, 155328, 159429, 159432, 153612, 157265, 155322, 155653, 159246, 156346, 153981, 155055, 155539, 155541, 154270, 155677, 160105, 160075, 157207, 160139, 158897, 158267, 158518, 158052, 156756, 154092, 155559, 155318, 155554, 155299, 155623, 155302, 159433, 159541, 157354, 155411, 154843, 158351, 158362, 159343, 156105, 154397, 158521, 154965, 154221, 156089, 157840, 159325, 159319, 160067, 160070, 159340, 159094, 159244, 158428, 156352, 158458, 159032, 158134, 158000, 157261, 155158, 153668, 153597, 153867, 154110, 154838, 155569, 156997, 157000, 154898, 155636, 157941, 155672, 155673, 155754, 160127, 160126, 158692, 158858, 156319, 157048, 157075, 154242, 154359, 153653, 155077, 155071, 158700, 159104, 158374, 158383, 158232, 157017, 157503, 157506, 155319, 155399, 155644, 155545, 155053, 157243, 159457, 157270, 155083, 157786, 159243, 155782, 158941, 159194, 159752, 153849, 155562, 155643, 155722, 154991, 159915, 155058, 154440, 156501, 158209, 155518, 153661, 158696, 158200, 158861, 157566, 155622, 155706, 155733, 155760, 155624, 154814, 157810, 157813, 157813, 157813, 157813, 157246, 159508, 159454, 157267, 155326, 158239, 158590, 159322, 159403, 158599, 158680, 156482, 155646, 157509, 155482, 154135, 159510, 159435, 153954, 154070, 153598, 155863, 159670, 157250, 154332, 154143, 158454, 157861, 160135, 156503, 158600, 159935, 159773, 159693, 159774, 157909, 159267, 155055, 154602, 154062, 158207, 156364, 156436, 156481, 156510, 154839, 157394, 157557, 159023, 159103, 156036, 155640, 155400, 
155321, 155563, 155626, 155545, 159433, 159460, 158731, 154357, 155464, 155515, 157481, 159264, 153999, 153711, 153747, 158561, 159290, 156310, 158210, 158617, 158543, 158679, 154309, 155758, 154323, 153975, 159581, 158214, 159671, 154578, 155565, 155645, 154168, 154264, 160155, 160074, 160073, 159423, 155047, 155707, 155269, 157157, 156347, 153657, 154356, 153648, 159209, 157021, 158506, 158587, 158416, 155965, 159532, 157346, 159983, 155809, 158212, 159193, 159753, 157008, 154587, 155561, 155551, 157729, 157813, 157813, 157813, 157813, 157489, 159511, 159457, 156541, 155461, 154408, 153806, 153744, 156590, 159046, 158254, 158108, 158354, 158353, 156177, 155061, 153626, 159448, 159127, 159976, 158079, 158706, 155844, 154624, 159240, 156511, 159913, 154815, 154818, 156492, 158686, 158275, 159077, 160049, 159256, 157800, 157798, 158596, 159157, 159346, 158641, 155780, 155697, 158759, 158110, 154159, 154437, 155547, 154089, 156384, 158687, 156500, 157084, 159184, 159187, 154817, 155726, 157940, 157939, 158684, 158131, 158057, 153644, 157507, 155564, 158434, 157408, 156676, 158156, 155895, 155126, 157395, 153671, 156420, 155868, 156368, 160019, 153850, 156171, 158685, 157886, 156509, 159132, 155274, 155277, 155196, 155168, 155141, 155138, 155626, 153478, 151674, 151712, 152596, 102177, 73670, 99877, 2073, 102274, 108243, 854, 99257, 100896, 3837, 100345, 26940, 107513, 99815, 85361, 23656, 25067, 105636, 3837, 58364, 52129, 2073, 100917, 108471, 101187, 99243, 100859, 854, 2073, 106957, 99476, 23305, 44729, 854, 100001, 100175, 107600, 9370, 119305, 100399, 70074, 6313, 151713, 153477, 157543, 155855, 158784, 158780, 158703, 155796, 156531, 155402, 153838, 156187, 156287, 158109, 158373, 156321, 154403, 157057, 156088, 160152, 159421, 155290, 157309, 157315, 157401, 158940, 156753, 159485, 159643, 158920, 154474, 154650, 154910, 159770, 158318, 158507, 154134, 153882, 153618, 159449, 156940, 154084, 156106, 159196, 159350, 158624, 154305, 154322, 154978, 154267, 160047, 159398, 155050, 154897, 154291, 155043, 154917, 159854, 156859, 158278, 157236, 154116, 154842, 160048, 157942, 157731, 159186, 158555, 159286, 157822, 160081, 157894, 157894, 157894, 157894, 157813, 157813, 157813, 154570, 153802, 153624, 158075, 160099, 159091, 159911, 159912, 156996, 153674, 155410, 154141, 153903, 154278, 157185, 159121, 158884, 157066, 158281, 154971, 153669, 159204, 159367, 159850, 156449, 154086, 154833, 154188, 156365, 159352, 158633, 159833, 159832, 159589, 157447, 154451, 157238, 157965, 153835, 154870, 158274, 154888, 155376, 155605, 156817, 153627, 159513, 158868, 156141, 155331, 155384, 156760, 159433, 159433, 159433, 159433, 155137, 157837, 160102, 160129, 157954, 155700, 154968, 159588, 159183, 158189, 158783, 155799, 153864, 156068, 158188, 159163, 156967, 158192, 157976, 156536, 155320, 159253, 154647, 153873, 153603, 158229, 156320, 157039, 158444, 158860, 158546, 157104, 155725, 154298, 159593, 156114, 153819, 154384, 157405, 159437, 159995, 154104, 155724, 155716, 155755, 154646, 154863, 154374, 157746, 159045, 158291, 159650, 157444, 159181, 158202, 153600, 155117, 157313, 157393, 155811, 159284, 160016, 159804, 159910, 158197, 158137, 155795, 157262, 155347, 159980, 157556, 156585, 153663, 155114, 156518, 158704, 155788, 155221, 155113, 156739, 153789, 155852, 159000, 154132, 156087, 158081, 155194, 154621, 156025, 154081, 155613, 154137, 158186, 158996, 159220, 158286, 153894, 153654, 158026, 158597, 156184, 158619, 158651, 159409, 155164, 159643, 156703, 155210, 157314, 157977, 156339, 
154862, 154861, 154727, 155568, 155574, 155007, 158688, 156280, 158536, 158581, 158402, 156651, 159643, 154471, 154677, 156288, 159044, 155555, 157894, 157894, 157813, 153478, 151675, 151712, 152596, 99936, 30534, 104609, 106594, 99882, 99615, 70074, 9370, 101485, 99499, 57566, 6313, 102245, 101325, 99164, 45861, 102504, 9370, 2073, 100268, 93, 102634, 44729, 93, 121407, 93, 119921, 93, 99800, 101953, 93, 33590, 99601, 49187, 36407, 100132, 104666, 101941, 6313, 151713, 153477, 154258, 156489, 155054, 154331, 154349, 159775, 157831, 158516, 156148, 158443, 158165, 155817, 153636, 155074, 155419, 155329, 159433, 156517, 154816, 159235, 156015, 154896, 154230, 154948, 158515, 157222, 154275, 155540, 155567, 159914, 155971, 158515, 158608, 160071, 157884, 157155, 154320, 155039, 157807, 156754, 155323, 157030, 158347, 156504, 154296, 157914, 157590, 157617, 157724, 159668, 158198, 158162, 158001, 156533, 159453, 157266, 155105, 155330, 157246, 155086, 154870, 158111, 156427, 155976, 157001, 154098, 154206, 158669, 159370, 157906, 159266, 157244, 153927, 153675, 158753, 159969, 157060, 153660, 155315, 159776, 154633, 158025, 157998, 156054, 156027, 153840, 154083, 154595, 155299, 157240, 154412, 154826, 157642, 157480, 159664, 158206, 155940, 155180, 155103, 155102, 155183, 155200, 159665, 157725, 155295, 155441, 155479, 155477, 155898, 158445, 158427, 158319, 159047, 157823, 157813, 157813, 157813, 157813, 157813, 157813, 157813, 157813, 155398, 159445, 157970, 153623, 158512, 156342, 158670, 158643, 160101, 159369, 155011, 157078, 159751, 157591, 155407, 154627, 156133, 158542, 154178, 154302, 153982, 158269, 158682, 158321, 159973, 158511, 158698, 159679, 159103, 158311, 159695, 155483, 158516, 155869, 156526, 157494, 154821, 154911, 155314, 155838, 158322, 158241, 158223, 158303, 158222, 155324, 155570, 155300, 155545, 155296, 157483, 157483, 155539, 155707, 157912, 157666, 156452, 158650, 158647, 158648, 158568, 158571, 156357, 154170, 154179, 154845, 154844, 155654, 155545, 155626, 157813, 157486, 159427, 157246, 155383, 154843, 158271, 156175, 158696, 155979, 156600, 153635, 159451, 155428, 159982, 159985, 159744, 159096, 158043, 158034, 158115, 158087, 158087, 158978, 157495, 157733, 157813, 157813, 157813, 157813, 157813, 157813, 155302, 159430, 159430, 157243, 155059, 155707, 153892, 158195, 159653, 159654, 157467, 157476, 157397, 159582, 154398, 158139, 158166, 158112, 155925, 156168, 156249, 154071, 154719, 155439, 155358, 155087, 155383, 157813, 157732, 157732, 157813, 157813, 157813, 157813, 157243, 159430, 156514, 155377, 157894, 154297, 158618, 159104, 158943, 158457, 156270, 156351, 154167, 154893, 155595, 155325, 155570, 155545, 155626, 157813, 157813, 157813, 157813, 157813, 157813, 157813, 155626, 159973, 159248, 158295, 159269, 158544, 156393, 158661, 158604, 158995, 160131, 156340, 158294, 159024, 159105, 158862, 159618, 159621, 159645, 159630, 159629, 159548, 157579, 159766, 157498, 157741, 157813, 157813, 157813, 157813, 157894, 155707, 157786, 158997, 157969, 158950, 159590, 157329, 153684, 154035, 155862, 155434, 153901, 156837, 156273, 156354, 156381, 158649, 158676, 158657, 158656, 159388, 160010, 157742, 157732, 157732, 157732, 157813, 157813, 157813, 157813, 157813, 157726, 159670, 159697, 157510, 157510, 157483, 157483, 157726, 157483, 157483, 157726, 157483, 157483, 157486, 157732, 160000, 160000, 160000, 160000, 160000, 159919, 159919, 159919, 159919, 159919, 159919, 157732, 157732, 159217, 158243, 158489, 153889, 154386, 154353, 156529, 157588, 156097, 158051, 
157974, 153752, 155365, 157975, 156112, 153729, 155076, 157752, 157881, 156099, 158133, 155865, 153681, 153735, 154464, 155084, 155327, 155056, 157243, 157240, 157240, 157240, 157813, 160000, 160000, 157813, 157813, 157813, 157813, 157813, 157894, 155593, 158105, 158592, 156426, 154239, 156507, 157233, 158725, 158671, 158509, 157799, 157071, 155856, 153750, 153993, 153741, 153984, 154389, 155351, 155573, 155545, 155545, 155644, 157426, 156695, 158371, 158560, 158264, 157508, 157589, 155509, 156085, 156103, 155934, 154410, 155922, 158265, 159237, 157805, 157806, 156834, 156807, 158994, 159830, 159829, 158939, 156750, 153825, 153816, 153705, 154461, 155112, 155354, 155327, 155545, 157813, 157813, 153478, 151674, 151712, 152596, 110602, 17447, 99469, 99354, 75882, 29635, 18947, 2073, 70074, 85254, 101904, 23836, 33590, 99258, 104090, 111077, 18800, 101044, 106063, 18397, 115639, 119754, 46553, 3837, 45181, 99391, 22697, 9370, 101780, 17523, 99882, 99615, 104188, 99530, 22697, 9370, 120728, 121909, 99293, 115807, 101182, 6313, 151713, 153477, 156815, 158031, 156031, 154620, 154128, 159365, 157423, 158399, 158173, 158960, 157260, 159535, 156730, 157323, 155541, 154824, 156290, 156268, 156367, 158268, 153732, 159182, 158447, 159131, 159591, 157404, 155217, 154542, 155488, 153760, 155139, 155061, 158409, 158201, 158914, 158836, 155993, 156081, 154883, 154126, 156414, 153678, 156542, 159643, 158839, 153743, 154191, 155558, 159534, 157777, 159010, 159345, 158096, 159933, 155481, 155353, 158380, 156283, 159316, 158668, 158606, 158220, 155061, 154341, 155159, 156301, 159164, 159622, 155176, 155538, 154566, 158211, 155115, 155627, 158947, 159676, 159433, 159433, 159433, 155222, 155392, 155294, 155136, 157002, 155802, 156374, 157156, 158932, 158683, 158613, 154903, 153919, 156034, 158529, 156324, 156297, 155919, 155275, 153895, 159516, 157491, 154556, 155530, 155046, 155052, 157882, 157951, 158187, 158435, 158138, 159753, 157809, 159834, 158891, 158806, 158915, 155902, 156162, 156405, 155767, 156009, 158355, 156411, 154080, 155500, 159562, 157915, 154174, 155306, 155058, 156123, 159155, 159884, 153828, 153948, 156392, 158620, 158448, 156267, 158039, 158672, 158356, 156498, 155025, 159368, 155443, 158024, 159515, 156762, 153823, 155524, 158616, 156060, 153621, 155617, 156823, 154881, 153972, 158798, 159041, 155582, 155626, 157894, 157894, 157813, 159757, 159430, 159430, 159430, 159430, 155383, 157759, 158350, 156358, 156483, 157101, 155480, 157795, 159242, 158106, 159270, 158625, 158674, 159167, 159643, 158914, 156638, 157475, 155127, 155814, 153742, 156082, 159265, 158665, 159424, 159425, 156274, 156454, 160154, 159176, 154150, 158290, 154407, 157734, 158630, 160093, 158385, 158787, 156033, 153736, 153628, 154096, 155349, 158718, 157585, 155837, 159996, 157083, 156462, 155937, 156428, 155590, 155591, 154161, 154629, 158865, 158930, 158914, 156655, 159662, 153642, 155892, 154957, 154243, 156844, 158184, 156014, 156584, 158436, 158696, 158282, 159081, 158488, 156348, 155261, 154722, 156492, 158565, 156506, 154987, 154294, 160155, 159424, 155691, 155708, 157813, 153478, 151675, 151712, 152596, 102177, 99360, 91777, 100569, 17340, 101376, 102073, 22697, 70074, 104387, 12857, 74393, 41505, 120965, 101240, 120965, 102565, 97907, 102138, 34718, 70074, 91956, 99662, 99318, 121657, 44729, 97907, 99318, 100893, 70074, 3837, 100132, 75606, 110261, 104754, 6313, 151713, 153477]
    outputs = llm.generate(TokensPrompt(prompt_token_ids=input),
                           sampling_params,
                           use_tqdm=False)[0].outputs[0].token_ids
    print(outputs)
    # files = glob(f"{args.data_list}/*_result.json")
    # files.sort()
    # for file in files:
    #     with open(file) as fin:
    #         test_sets = json.load(fin)
    #     for test_set in test_sets:
    #         input = test_set["input"]
    #         set_all_random_seed(args.seed)
    #         llm_outputs = model.llm.generate(input, sampling_params)['token_ids']
    #         set_all_random_seed(args.seed)
    #         import pdb;pdb.set_trace()
    #         outputs = llm.generate(TokensPrompt(prompt_token_ids=input),
    #                                VllmSamplingParams(**asdict(sampling_params)),
    #                                use_tqdm=False)[0].outputs[0].token_ids
    #         print(llm_outputs)
    #         print(outputs)
    #         print("=========")
    #         import pdb;pdb.set_trace()

if __name__ == "__main__":
    main()
soulxpodcast/cli/soulxpodcast.py ADDED
@@ -0,0 +1,273 @@
import argparse
import json
import os
import random
import sys
import time
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
import importlib.util

import numpy as np
import onnxruntime
import s3tokenizer
import torch
import torchaudio
import torchaudio.compliance.kaldi as kaldi
from torch.utils.data import DataLoader, Dataset, SequentialSampler
from tqdm import tqdm

from soulxpodcast.config import Config, SoulXPodcastLLMConfig, SamplingParams
from soulxpodcast.models.soulxpodcast import SoulXPodcast
from soulxpodcast.utils.dataloader import PodcastDataset
from soulxpodcast.utils.audio import mel_spectrogram


def set_all_random_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def save_file_async(
    wav, uttid,
    info,
    is_sub,
):
    """Save audio asynchronously."""
    try:
        parentdir = f"{os.path.dirname(info['wav'])}"
        basename = os.path.basename(info["wav"]).split(".")[0]
        if is_sub:
            parentdir = f"{parentdir}/individual_clips"
        os.makedirs(parentdir, exist_ok=True)
        if wav is not None:
            wav = wav.cpu()
            torchaudio.save(f'{parentdir}/{uttid}.wav', wav, 24000)
            duration = wav.shape[-1] / 24000.0
        else:
            duration = 0.0
        if not is_sub:
            with open(f"{parentdir}/{basename}.json", "w") as f:
                json.dump(info, f, ensure_ascii=False, indent=4)
        return duration
    except Exception as e:
        timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
        tqdm.write(f"[{timestamp}] - [ERROR] - Error saving audio {info.get('key', 'unknown')}: {e}")
        return 0.0

def collate_fn(batch):
    assert len(batch) == 1
    data = batch[0]

    # prepare prompt mels for llm, spk_emb + prompt mel for flow;
    prompt_mels_for_llm, prompt_mels_lens_for_llm = s3tokenizer.padding(data["log_mel"])  # [B, num_mels=128, T]
    spk_emb_for_flow = torch.tensor(data["spk_emb"])
    prompt_mels_for_flow = data["mel"]

    # prepare text + spk for llm;
    text_tokens_for_llm = data["text_tokens"]
    prompt_text_tokens_for_llm = data["prompt_text_tokens"]
    spk_ids = data["spks_list"]
    sampling_params = SamplingParams()
    infos = [data["info"]]

    processed_data = {
        "prompt_mels_for_llm": prompt_mels_for_llm,
        "prompt_mels_lens_for_llm": prompt_mels_lens_for_llm,
        "prompt_text_tokens_for_llm": prompt_text_tokens_for_llm,
        "text_tokens_for_llm": text_tokens_for_llm,
        "prompt_mels_for_flow_ori": prompt_mels_for_flow,
        "spk_emb_for_flow": spk_emb_for_flow,
        "sampling_params": sampling_params,
        "spk_ids": spk_ids,
        "infos": infos,
    }

    if data.get("use_prompt_cot", False):
        processed_data.update({
            "use_prompt_cot": True,
            "prompt_cot_text_tokens_for_llm": data["prompt_cot_text_tokens"],
            "prompt_cot_prefix": data["prompt_cot_prefix"],
        })
    return processed_data


def get_args():
    parser = argparse.ArgumentParser(description='FlashCosyVoice')
    parser.add_argument('--model_path',
                        required=True,
                        type=str,
                        help='model path')
    parser.add_argument('--data_list',
                        required=True,
                        type=str,
                        help='data list')
    parser.add_argument('--num_workers',
                        type=int,
                        default=4,
                        help='workers for dataloader')
    parser.add_argument('--prefetch',
                        type=int,
                        default=5,
                        help='prefetch for dataloader')
    parser.add_argument('--llm_engine',
                        type=str,
                        default="hf",
                        help='model execute engine')
    parser.add_argument('--fp16_flow',
                        action='store_true',
                        help='enable fp16 flow')
    parser.add_argument('--seed',
                        type=int,
                        default=1986,
                        help='random seed for generation')
    parser.add_argument('--save_intermediate',
                        action='store_true',
                        help='enable save intermediate result in long form.')
    args = parser.parse_args()
    return args
|
| 132 |
+
|
| 133 |
+
def main():
|
| 134 |
+
args = get_args()
|
| 135 |
+
|
| 136 |
+
assert (torch.cuda.is_available())
|
| 137 |
+
hf_config = SoulXPodcastLLMConfig.from_initial_and_json(
|
| 138 |
+
initial_values={"fp16_flow": args.fp16_flow},
|
| 139 |
+
json_file=f"{args.model_path}/soulxpodcast_config.json")
|
| 140 |
+
|
| 141 |
+
# Compatible with the absence of a VLLM installation
|
| 142 |
+
llm_engine = args.llm_engine
|
| 143 |
+
if llm_engine == "vllm":
|
| 144 |
+
if not importlib.util.find_spec("vllm"):
|
| 145 |
+
llm_engine = "hf"
|
| 146 |
+
timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
|
| 147 |
+
tqdm.write(f"[{timestamp}] - [WARNING]: No install VLLM, switch to hf engine.")
|
| 148 |
+
|
| 149 |
+
config = Config(model=args.model_path, enforce_eager=True, llm_engine=llm_engine,
|
| 150 |
+
hf_config=hf_config)
|
| 151 |
+
model = SoulXPodcast(config)
|
| 152 |
+
|
| 153 |
+
set_all_random_seed(args.seed)
|
| 154 |
+
|
| 155 |
+
dataset = PodcastDataset(model.llm.tokenizer, args.data_list, config)
|
| 156 |
+
sampler = SequentialSampler(dataset,)
|
| 157 |
+
dataloader = DataLoader(dataset, batch_size=1, num_workers=args.num_workers, pin_memory=True,
|
| 158 |
+
sampler=sampler, shuffle=False, prefetch_factor=args.prefetch, collate_fn=collate_fn)
|
| 159 |
+
total_steps = len(dataset)
|
| 160 |
+
|
| 161 |
+
timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
|
| 162 |
+
tqdm.write(f"[{timestamp}] - [INFO] - {args}")
|
| 163 |
+
progress_bar = tqdm(total=total_steps, desc="Processing samples", unit="wav",
|
| 164 |
+
position=0, leave=True, dynamic_ncols=True)
|
| 165 |
+
|
| 166 |
+
cpu_counts = os.cpu_count()
|
| 167 |
+
executor = ThreadPoolExecutor(max_workers=min(args.num_workers, cpu_counts // 2))
|
| 168 |
+
|
| 169 |
+
pending_futures = []
|
| 170 |
+
dataloader_iter = iter(dataloader)
|
| 171 |
+
succeed_duration = 0.01 # avoid division by zero
|
| 172 |
+
start_time = time.time()
|
| 173 |
+
estimated_total_wavs = 0
|
| 174 |
+
succeed_wavs = 0
|
| 175 |
+
failed_wavs = 0
|
| 176 |
+
last_print_time = start_time
|
| 177 |
+
|
| 178 |
+
while True:
|
| 179 |
+
try:
|
| 180 |
+
batch = next(dataloader_iter)
|
| 181 |
+
|
| 182 |
+
if len(batch['infos']) == 0:
|
| 183 |
+
timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
|
| 184 |
+
tqdm.write(f"[{timestamp}] - [WARNING]: No valid batch found, skipping this batch...")
|
| 185 |
+
continue
|
| 186 |
+
|
| 187 |
+
results_dict = model.forward_longform(**batch)
|
| 188 |
+
|
| 189 |
+
estimated_total_wavs += len(results_dict['generated_wavs'])
|
| 190 |
+
|
| 191 |
+
uttid = batch['infos'][0]['key']
|
| 192 |
+
result = None
|
| 193 |
+
for i in range(len(results_dict['generated_wavs'])):
|
| 194 |
+
is_sub = True
|
| 195 |
+
if args.save_intermediate:
|
| 196 |
+
future = executor.submit(
|
| 197 |
+
save_file_async, results_dict['generated_wavs'][i],
|
| 198 |
+
f"{uttid}_turn_{str(i).zfill(2)}", batch['infos'][0].copy(), is_sub
|
| 199 |
+
)
|
| 200 |
+
pending_futures.append(future)
|
| 201 |
+
if result is None:
|
| 202 |
+
result = results_dict['generated_wavs'][i]
|
| 203 |
+
else:
|
| 204 |
+
result = torch.concat([result, results_dict['generated_wavs'][i]], axis=1)
|
| 205 |
+
future = executor.submit(
|
| 206 |
+
save_file_async, result,
|
| 207 |
+
f"{uttid}", batch['infos'][0].copy(), False
|
| 208 |
+
)
|
| 209 |
+
pending_futures.append(future)
|
| 210 |
+
completed_futures = []
|
| 211 |
+
for future in pending_futures:
|
| 212 |
+
if future.done():
|
| 213 |
+
try:
|
| 214 |
+
duration = future.result()
|
| 215 |
+
succeed_duration += duration
|
| 216 |
+
succeed_wavs += 1
|
| 217 |
+
except Exception as e:
|
| 218 |
+
failed_wavs += 1
|
| 219 |
+
timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
|
| 220 |
+
tqdm.write(f"[{timestamp}] - [ERROR]: Error in async save task: {e}")
|
| 221 |
+
completed_futures.append(future)
|
| 222 |
+
|
| 223 |
+
for future in completed_futures:
|
| 224 |
+
pending_futures.remove(future)
|
| 225 |
+
|
| 226 |
+
update_n = 1
|
| 227 |
+
if progress_bar.n + update_n > progress_bar.total:
|
| 228 |
+
progress_bar.update(progress_bar.total - progress_bar.n)
|
| 229 |
+
else:
|
| 230 |
+
progress_bar.update(update_n)
|
| 231 |
+
|
| 232 |
+
current_time = time.time()
|
| 233 |
+
if current_time - last_print_time >= 120:
|
| 234 |
+
elapsed_time = current_time - start_time
|
| 235 |
+
avg_duration = succeed_duration / succeed_wavs if succeed_wavs > 0 else 0
|
| 236 |
+
estimated_total_duration = avg_duration * estimated_total_wavs
|
| 237 |
+
timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
|
| 238 |
+
tqdm.write(f"[{timestamp}] - [INFO]: Estimated total wavs: {estimated_total_wavs} ({estimated_total_wavs - succeed_wavs} pending to save), Succeed wavs: {succeed_wavs}, Failed wavs: {failed_wavs}, Estimated total duration: {estimated_total_duration:.2f}s ({estimated_total_duration / 3600:.2f} h), Elapsed time: {elapsed_time:.2f}s") # noqa
|
| 239 |
+
last_print_time = current_time
|
| 240 |
+
except StopIteration:
|
| 241 |
+
break
|
| 242 |
+
except Exception as e:
|
| 243 |
+
failed_wavs += 1
|
| 244 |
+
timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
|
| 245 |
+
tqdm.write(f"[{timestamp}] - [ERROR]: Error in main loop: {e}")
|
| 246 |
+
continue
|
| 247 |
+
|
| 248 |
+
total_time = time.time() - start_time
|
| 249 |
+
|
| 250 |
+
timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
|
| 251 |
+
tqdm.write(f"[{timestamp}] - [INFO] - Waiting for {len(pending_futures)} pending save tasks to complete...")
|
| 252 |
+
|
| 253 |
+
for future in pending_futures:
|
| 254 |
+
try:
|
| 255 |
+
duration = future.result(timeout=60)
|
| 256 |
+
succeed_duration += duration
|
| 257 |
+
succeed_wavs += 1
|
| 258 |
+
except Exception as e:
|
| 259 |
+
failed_wavs += 1
|
| 260 |
+
timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
|
| 261 |
+
tqdm.write(f"[{timestamp}] - [ERROR]: Error in final async save task: {e}")
|
| 262 |
+
executor.shutdown(wait=True)
|
| 263 |
+
|
| 264 |
+
timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
|
| 265 |
+
tqdm.write(f"[{timestamp}] - [INFO]: All async save tasks completed.")
|
| 266 |
+
progress_bar.close()
|
| 267 |
+
|
| 268 |
+
timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
|
| 269 |
+
tqdm.write(f"[{timestamp}] - [INFO]: Final Report - Succeed wavs: {succeed_wavs}, Failed wavs: {failed_wavs}, Total duration: {succeed_duration:.2f}s ({succeed_duration / 3600:.2f} h).") # noqa
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
if __name__ == "__main__":
|
| 273 |
+
main()
|
soulxpodcast/config.py
ADDED
|
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from dataclasses import dataclass, field, fields, is_dataclass, asdict
|
| 3 |
+
from typing import Any, Dict, List, Optional
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
import json
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from transformers import AutoConfig
|
| 9 |
+
from transformers import PretrainedConfig
|
| 10 |
+
|
| 11 |
+
@dataclass
|
| 12 |
+
class SoulXPodcastLLMConfig:
|
| 13 |
+
architectures: list[str] = field(default_factory=lambda: ["Qwen3ForCausalLM"])
|
| 14 |
+
attention_dropout: float = 0.0
|
| 15 |
+
bos_token_id: int = 151643
|
| 16 |
+
eos_token_id: int = 151675 # speech eos
|
| 17 |
+
hidden_act: str = "silu"
|
| 18 |
+
hidden_size: int = 2048
|
| 19 |
+
initializer_range: float = 0.02
|
| 20 |
+
intermediate_size: int = 6144
|
| 21 |
+
max_position_embeddings: int = 40960
|
| 22 |
+
max_window_layers: int = 28
|
| 23 |
+
model_type: str = "qwen3"
|
| 24 |
+
num_attention_heads: int = 16
|
| 25 |
+
num_hidden_layers: int = 28
|
| 26 |
+
num_key_value_heads: int = 8
|
| 27 |
+
head_dim: int = 128
|
| 28 |
+
rms_norm_eps: float = 1e-06
|
| 29 |
+
rope_scaling: dict | None = None
|
| 30 |
+
rope_theta: float = 1000000.0
|
| 31 |
+
sliding_window: int = 32768
|
| 32 |
+
tie_word_embeddings: bool = True
|
| 33 |
+
torch_dtype: str = "bfloat16"
|
| 34 |
+
transformers_version: str = "4.52.3"
|
| 35 |
+
use_cache: bool = True
|
| 36 |
+
use_sliding_window: bool = False
|
| 37 |
+
vocab_size: int = 159488 # text_vocab_size + speech_vocab_size + 2 (eos and task_id)
|
| 38 |
+
lm_head_bias: bool = False
|
| 39 |
+
qkv_bias: bool = False
|
| 40 |
+
fp16_flow: bool = False
|
| 41 |
+
speech_token_offset: int = 152927
|
| 42 |
+
|
| 43 |
+
@classmethod
|
| 44 |
+
def from_initial_and_json(
|
| 45 |
+
cls,
|
| 46 |
+
initial_values: Dict[str, Any] = None,
|
| 47 |
+
json_file: Optional[str] = None
|
| 48 |
+
):
|
| 49 |
+
"""
|
| 50 |
+
Create instance from initial values and JSON data
|
| 51 |
+
|
| 52 |
+
Args:
|
| 53 |
+
initial_values: Initial key-value dict, which will overrides all other configurations
|
| 54 |
+
json_file: JSON file path
|
| 55 |
+
|
| 56 |
+
Returns:
|
| 57 |
+
SoulXPodcastLLMConfig instance
|
| 58 |
+
"""
|
| 59 |
+
# Merge all data sources
|
| 60 |
+
merged_data = {}
|
| 61 |
+
|
| 62 |
+
# 1. Load from JSON file first (lowest priority)
|
| 63 |
+
if json_file and os.path.exists(json_file):
|
| 64 |
+
file_data = cls._load_json_file(json_file)
|
| 65 |
+
merged_data.update(file_data)
|
| 66 |
+
|
| 67 |
+
# 2. Override with initial values last (highest priority)
|
| 68 |
+
if initial_values:
|
| 69 |
+
merged_data.update(initial_values)
|
| 70 |
+
|
| 71 |
+
# 3. Extract dataclass fields
|
| 72 |
+
valid_fields = {f.name for f in fields(cls)}
|
| 73 |
+
init_data = {k: v for k, v in merged_data.items() if k in valid_fields}
|
| 74 |
+
|
| 75 |
+
return cls(**init_data)
|
| 76 |
+
|
| 77 |
+
@staticmethod
|
| 78 |
+
def _load_json_file(file_path: str) -> Dict[str, Any]:
|
| 79 |
+
"""从JSON文件加载数据"""
|
| 80 |
+
path = Path(file_path)
|
| 81 |
+
if not path.exists():
|
| 82 |
+
return {}
|
| 83 |
+
with open(path, 'r', encoding='utf-8') as f:
|
| 84 |
+
return json.load(f)
|
| 85 |
+
|
| 86 |
+
class AutoPretrainedConfig(PretrainedConfig):
|
| 87 |
+
model_type = "qwen3"
|
| 88 |
+
|
| 89 |
+
def __init__(self, **kwargs):
|
| 90 |
+
# Remove non-configuration parameters
|
| 91 |
+
config_kwargs = {k: v for k, v in kwargs.items()
|
| 92 |
+
if not k.startswith('_') and k != 'self'}
|
| 93 |
+
super().__init__(**config_kwargs)
|
| 94 |
+
|
| 95 |
+
@classmethod
|
| 96 |
+
def from_dataclass(cls, dataclass_config):
|
| 97 |
+
"""Dynamically generate config from dataclass"""
|
| 98 |
+
if not is_dataclass(dataclass_config):
|
| 99 |
+
raise ValueError("Input must be a dataclass instance")
|
| 100 |
+
|
| 101 |
+
dataclass_dict = asdict(dataclass_config)
|
| 102 |
+
return cls(**dataclass_dict)
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
@dataclass
|
| 106 |
+
class SamplingParams:
|
| 107 |
+
temperature: float = 0.6
|
| 108 |
+
repetition_penalty: float = 1.25
|
| 109 |
+
top_k: int = 100
|
| 110 |
+
top_p: float = 0.9
|
| 111 |
+
max_tokens: int = 3000
|
| 112 |
+
min_tokens: int = 8
|
| 113 |
+
stop_token_ids: list[int] = field(default_factory=lambda: [151675])
|
| 114 |
+
# RasSampler parameters
|
| 115 |
+
use_ras: bool = True
|
| 116 |
+
win_size: int = 25
|
| 117 |
+
tau_r: float = 0.2
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
@dataclass
|
| 121 |
+
class Config:
|
| 122 |
+
model: str
|
| 123 |
+
max_model_len: int = 8192 # 15s prompt + 30s generated audio for 25hz audio tokenizer
|
| 124 |
+
gpu_memory_utilization: float = 0.9
|
| 125 |
+
tensor_parallel_size: int = 1
|
| 126 |
+
enforce_eager: bool = False
|
| 127 |
+
hf_config: SoulXPodcastLLMConfig | AutoConfig = field(default_factory=SoulXPodcastLLMConfig)
|
| 128 |
+
eos: int = -1
|
| 129 |
+
llm_engine: str = "hf" # support hf, nano-vllm
|
| 130 |
+
max_turn_size: int = 14
|
| 131 |
+
turn_tokens_threshold: int = 6192
|
| 132 |
+
|
| 133 |
+
prompt_context: int = 2 # default to 2 for two-speaker podcast;
|
| 134 |
+
history_context: int = 4
|
| 135 |
+
history_text_context: int = 4
|
| 136 |
+
|
| 137 |
+
def __post_init__(self):
|
| 138 |
+
assert os.path.isdir(self.model)
|
| 139 |
+
|
| 140 |
+
max_pos = getattr(self.hf_config, "max_position_embeddings", 8192)
|
| 141 |
+
self.max_model_len = min(self.max_model_len, max_pos)
|
soulxpodcast/engine/__init__.py
ADDED
|
File without changes
|
soulxpodcast/engine/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (184 Bytes). View file
|
|
|
soulxpodcast/engine/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (172 Bytes). View file
|
|
|
soulxpodcast/engine/__pycache__/llm_engine.cpython-311.pyc
ADDED
|
Binary file (7.81 kB). View file
|
|
|
soulxpodcast/engine/__pycache__/llm_engine.cpython-312.pyc
ADDED
|
Binary file (6.51 kB). View file
|
|
|
soulxpodcast/engine/llm_engine.py
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import types
|
| 2 |
+
import atexit
|
| 3 |
+
from dataclasses import fields, asdict
|
| 4 |
+
from time import perf_counter
|
| 5 |
+
import os
|
| 6 |
+
from functools import partial
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.multiprocessing as mp
|
| 10 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteriaList
|
| 11 |
+
from transformers import EosTokenCriteria, RepetitionPenaltyLogitsProcessor
|
| 12 |
+
try:
|
| 13 |
+
from vllm import LLM
|
| 14 |
+
from vllm import SamplingParams as VllmSamplingParams
|
| 15 |
+
from vllm.inputs import TokensPrompt as TokensPrompt
|
| 16 |
+
SUPPORT_VLLM = True
|
| 17 |
+
except ImportError:
|
| 18 |
+
SUPPORT_VLLM = False
|
| 19 |
+
|
| 20 |
+
from soulxpodcast.models.modules.sampler import _ras_sample_hf_engine
|
| 21 |
+
from soulxpodcast.config import Config, SamplingParams
|
| 22 |
+
|
| 23 |
+
class HFLLMEngine:
|
| 24 |
+
|
| 25 |
+
def __init__(self, model, **kwargs):
|
| 26 |
+
config_fields = {field.name for field in fields(Config)}
|
| 27 |
+
config_kwargs = {k: v for k, v in kwargs.items() if k in config_fields}
|
| 28 |
+
config = Config(model, **config_kwargs)
|
| 29 |
+
|
| 30 |
+
self.tokenizer = AutoTokenizer.from_pretrained(model, use_fast=True)
|
| 31 |
+
config.eos = config.hf_config.eos_token_id # speech eos token;
|
| 32 |
+
self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
| 33 |
+
self.model = AutoModelForCausalLM.from_pretrained(model, torch_dtype=torch.bfloat16, device_map=self.device)
|
| 34 |
+
self.config = config
|
| 35 |
+
self.pad_token_id = self.tokenizer.pad_token_id
|
| 36 |
+
|
| 37 |
+
def generate(
|
| 38 |
+
self,
|
| 39 |
+
prompt: list[str],
|
| 40 |
+
sampling_param: SamplingParams,
|
| 41 |
+
past_key_values=None,
|
| 42 |
+
) -> dict:
|
| 43 |
+
|
| 44 |
+
# Recreate stopping_criteria per request for thread safety
|
| 45 |
+
stopping_criteria = StoppingCriteriaList([EosTokenCriteria(eos_token_id=self.config.hf_config.eos_token_id)])
|
| 46 |
+
if sampling_param.use_ras:
|
| 47 |
+
sample_hf_engine_handler = partial(_ras_sample_hf_engine,
|
| 48 |
+
use_ras=sampling_param.use_ras,
|
| 49 |
+
win_size=sampling_param.win_size, tau_r=sampling_param.tau_r)
|
| 50 |
+
else:
|
| 51 |
+
sample_hf_engine_handler = None
|
| 52 |
+
rep_pen_processor = RepetitionPenaltyLogitsProcessor(
|
| 53 |
+
penalty=sampling_param.repetition_penalty,
|
| 54 |
+
prompt_ignore_length=len(prompt)
|
| 55 |
+
) # exclude the input prompt, consistent with vLLM implementation;
|
| 56 |
+
with torch.no_grad(): # Avoids gradient computation with no_grad
|
| 57 |
+
input_len = len(prompt)
|
| 58 |
+
generated_ids = self.model.generate(
|
| 59 |
+
input_ids = torch.tensor([prompt], dtype=torch.int64).to(self.device),
|
| 60 |
+
do_sample=True,
|
| 61 |
+
top_k=sampling_param.top_k,
|
| 62 |
+
top_p=sampling_param.top_p,
|
| 63 |
+
min_new_tokens=sampling_param.min_tokens,
|
| 64 |
+
max_new_tokens=sampling_param.max_tokens,
|
| 65 |
+
temperature=sampling_param.temperature,
|
| 66 |
+
repetition_penalty=sampling_param.repetition_penalty,
|
| 67 |
+
stopping_criteria=stopping_criteria,
|
| 68 |
+
past_key_values=past_key_values,
|
| 69 |
+
custom_generate=sample_hf_engine_handler,
|
| 70 |
+
use_cache=True,
|
| 71 |
+
logits_processor=[rep_pen_processor]
|
| 72 |
+
)
|
| 73 |
+
generated_ids = generated_ids[:, input_len:].cpu().numpy().tolist()[0]
|
| 74 |
+
output = {
|
| 75 |
+
"text": self.tokenizer.decode(generated_ids),
|
| 76 |
+
"token_ids": generated_ids,
|
| 77 |
+
}
|
| 78 |
+
return output
|
| 79 |
+
|
| 80 |
+
class VLLMEngine:
|
| 81 |
+
|
| 82 |
+
def __init__(self, model, **kwargs):
|
| 83 |
+
|
| 84 |
+
config_fields = {field.name for field in fields(Config)}
|
| 85 |
+
config_kwargs = {k: v for k, v in kwargs.items() if k in config_fields}
|
| 86 |
+
config = Config(model, **config_kwargs)
|
| 87 |
+
|
| 88 |
+
self.tokenizer = AutoTokenizer.from_pretrained(config.model, use_fast=True)
|
| 89 |
+
config.eos = config.hf_config.eos_token_id # speech eos token;
|
| 90 |
+
self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
| 91 |
+
os.environ["VLLM_USE_V1"] = "0"
|
| 92 |
+
if SUPPORT_VLLM:
|
| 93 |
+
self.model = LLM(model=model, enforce_eager=True, dtype="bfloat16", max_model_len=8192, enable_prefix_caching=True)
|
| 94 |
+
else:
|
| 95 |
+
raise ImportError("Not Support VLLM now!!!")
|
| 96 |
+
self.config = config
|
| 97 |
+
self.pad_token_id = self.tokenizer.pad_token_id
|
| 98 |
+
|
| 99 |
+
def generate(
|
| 100 |
+
self,
|
| 101 |
+
prompt: list[str],
|
| 102 |
+
sampling_param: SamplingParams,
|
| 103 |
+
past_key_values=None,
|
| 104 |
+
) -> dict:
|
| 105 |
+
sampling_param.stop_token_ids = [self.config.hf_config.eos_token_id]
|
| 106 |
+
with torch.no_grad(): # Avoids gradient computation with no_grad
|
| 107 |
+
generated_ids = self.model.generate(
|
| 108 |
+
TokensPrompt(prompt_token_ids=prompt),
|
| 109 |
+
VllmSamplingParams(**asdict(sampling_param)),
|
| 110 |
+
use_tqdm=False,
|
| 111 |
+
)[0].outputs[0].token_ids
|
| 112 |
+
output = {
|
| 113 |
+
"text": self.tokenizer.decode(generated_ids),
|
| 114 |
+
"token_ids": list(generated_ids),
|
| 115 |
+
}
|
| 116 |
+
return output
|
soulxpodcast/models/__pycache__/soulxpodcast.cpython-311.pyc
ADDED
|
Binary file (12.7 kB). View file
|
|
|
soulxpodcast/models/__pycache__/soulxpodcast.cpython-312.pyc
ADDED
|
Binary file (11 kB). View file
|
|
|
soulxpodcast/models/modules/__init__.py
ADDED
|
File without changes
|
soulxpodcast/models/modules/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (192 Bytes). View file
|
|
|
soulxpodcast/models/modules/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (180 Bytes). View file
|
|
|
soulxpodcast/models/modules/__pycache__/flow.cpython-311.pyc
ADDED
|
Binary file (11.7 kB). View file
|
|
|
soulxpodcast/models/modules/__pycache__/flow.cpython-312.pyc
ADDED
|
Binary file (10.9 kB). View file
|
|
|
soulxpodcast/models/modules/__pycache__/hifigan.cpython-311.pyc
ADDED
|
Binary file (15 kB). View file
|
|
|
soulxpodcast/models/modules/__pycache__/hifigan.cpython-312.pyc
ADDED
|
Binary file (13.8 kB). View file
|
|
|
soulxpodcast/models/modules/__pycache__/sampler.cpython-311.pyc
ADDED
|
Binary file (9.64 kB). View file
|
|
|
soulxpodcast/models/modules/__pycache__/sampler.cpython-312.pyc
ADDED
|
Binary file (9.25 kB). View file
|
|
|
soulxpodcast/models/modules/flow.py
ADDED
|
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
|
| 7 |
+
from soulxpodcast.models.modules.flow_components.estimator import \
|
| 8 |
+
CausalConditionalDecoder
|
| 9 |
+
from soulxpodcast.models.modules.flow_components.upsample_encoder import (
|
| 10 |
+
UpsampleConformerEncoder, make_pad_mask)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
@dataclass
|
| 14 |
+
class CfmParams:
|
| 15 |
+
sigma_min: float = 1e-6
|
| 16 |
+
solver: str = "euler"
|
| 17 |
+
t_scheduler: str = "cosine"
|
| 18 |
+
training_cfg_rate: float = 0.2
|
| 19 |
+
inference_cfg_rate: float = 0.7
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class CausalConditionalCFM(torch.nn.Module):
|
| 23 |
+
def __init__(self, in_channels=320, cfm_params=CfmParams(), n_spks=1, spk_emb_dim=80, estimator: torch.nn.Module = None):
|
| 24 |
+
super().__init__()
|
| 25 |
+
self.n_feats = in_channels
|
| 26 |
+
self.n_spks = n_spks
|
| 27 |
+
self.spk_emb_dim = spk_emb_dim
|
| 28 |
+
self.solver = cfm_params.solver
|
| 29 |
+
if hasattr(cfm_params, "sigma_min"):
|
| 30 |
+
self.sigma_min = cfm_params.sigma_min
|
| 31 |
+
else:
|
| 32 |
+
self.sigma_min = 1e-4
|
| 33 |
+
self.t_scheduler = cfm_params.t_scheduler
|
| 34 |
+
self.training_cfg_rate = cfm_params.training_cfg_rate
|
| 35 |
+
self.inference_cfg_rate = cfm_params.inference_cfg_rate
|
| 36 |
+
in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0)
|
| 37 |
+
# Just change the architecture of the estimator here
|
| 38 |
+
self.estimator = CausalConditionalDecoder() if estimator is None else estimator
|
| 39 |
+
|
| 40 |
+
@torch.inference_mode()
|
| 41 |
+
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, streaming=False):
|
| 42 |
+
"""Forward diffusion
|
| 43 |
+
|
| 44 |
+
Args:
|
| 45 |
+
mu (torch.Tensor): output of encoder
|
| 46 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
| 47 |
+
mask (torch.Tensor): output_mask
|
| 48 |
+
shape: (batch_size, 1, mel_timesteps)
|
| 49 |
+
n_timesteps (int): number of diffusion steps
|
| 50 |
+
temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
|
| 51 |
+
spks (torch.Tensor, optional): speaker ids. Defaults to None.
|
| 52 |
+
shape: (batch_size, spk_emb_dim)
|
| 53 |
+
cond: Not used but kept for future purposes
|
| 54 |
+
|
| 55 |
+
Returns:
|
| 56 |
+
sample: generated mel-spectrogram
|
| 57 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
| 58 |
+
"""
|
| 59 |
+
z = torch.randn_like(mu).to(mu.device).to(mu.dtype) * temperature
|
| 60 |
+
# fix prompt and overlap part mu and z
|
| 61 |
+
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
|
| 62 |
+
if self.t_scheduler == 'cosine':
|
| 63 |
+
t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
|
| 64 |
+
return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond, streaming=streaming), None
|
| 65 |
+
|
| 66 |
+
def solve_euler(self, x, t_span, mu, mask, spks, cond, streaming=False):
|
| 67 |
+
"""
|
| 68 |
+
Fixed euler solver for ODEs.
|
| 69 |
+
Args:
|
| 70 |
+
x (torch.Tensor): random noise
|
| 71 |
+
t_span (torch.Tensor): n_timesteps interpolated
|
| 72 |
+
shape: (n_timesteps + 1,)
|
| 73 |
+
mu (torch.Tensor): output of encoder
|
| 74 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
| 75 |
+
mask (torch.Tensor): output_mask
|
| 76 |
+
shape: (batch_size, 1, mel_timesteps)
|
| 77 |
+
spks (torch.Tensor, optional): speaker ids. Defaults to None.
|
| 78 |
+
shape: (batch_size, spk_emb_dim)
|
| 79 |
+
cond: Not used but kept for future purposes
|
| 80 |
+
"""
|
| 81 |
+
batch_size = x.size(0)
|
| 82 |
+
t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
|
| 83 |
+
|
| 84 |
+
# I am storing this because I can later plot it by putting a debugger here and saving it to a file
|
| 85 |
+
# Or in future might add like a return_all_steps flag
|
| 86 |
+
sol = []
|
| 87 |
+
|
| 88 |
+
# Do not use concat, it may cause memory format changed and trt infer with wrong results!
|
| 89 |
+
# Create tensors with double batch size for CFG (conditional + unconditional)
|
| 90 |
+
x_in = torch.zeros([batch_size * 2, x.size(1), x.size(2)], device=x.device, dtype=x.dtype)
|
| 91 |
+
mask_in = torch.zeros([batch_size * 2, mask.size(1), mask.size(2)], device=x.device, dtype=x.dtype)
|
| 92 |
+
mu_in = torch.zeros([batch_size * 2, mu.size(1), mu.size(2)], device=x.device, dtype=x.dtype)
|
| 93 |
+
t_in = torch.zeros([batch_size * 2], device=x.device, dtype=x.dtype)
|
| 94 |
+
spks_in = torch.zeros([batch_size * 2, spks.size(1)], device=x.device, dtype=x.dtype)
|
| 95 |
+
cond_in = torch.zeros([batch_size * 2, cond.size(1), cond.size(2)], device=x.device, dtype=x.dtype)
|
| 96 |
+
|
| 97 |
+
for step in range(1, len(t_span)):
|
| 98 |
+
# Classifier-Free Guidance inference introduced in VoiceBox
|
| 99 |
+
# Copy conditional and unconditional input
|
| 100 |
+
x_in[:batch_size] = x
|
| 101 |
+
x_in[batch_size:] = x
|
| 102 |
+
mask_in[:batch_size] = mask
|
| 103 |
+
mask_in[batch_size:] = mask
|
| 104 |
+
mu_in[:batch_size] = mu
|
| 105 |
+
# Unconditional part remains 0
|
| 106 |
+
t_in.fill_(t)
|
| 107 |
+
spks_in[:batch_size] = spks
|
| 108 |
+
cond_in[:batch_size] = cond
|
| 109 |
+
|
| 110 |
+
dphi_dt = self.estimator(
|
| 111 |
+
x_in, mask_in,
|
| 112 |
+
mu_in, t_in,
|
| 113 |
+
spks_in,
|
| 114 |
+
cond_in,
|
| 115 |
+
streaming
|
| 116 |
+
)
|
| 117 |
+
dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [batch_size, batch_size], dim=0)
|
| 118 |
+
dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
|
| 119 |
+
x = x + dt * dphi_dt
|
| 120 |
+
t = t + dt
|
| 121 |
+
sol.append(x)
|
| 122 |
+
if step < len(t_span) - 1:
|
| 123 |
+
dt = t_span[step + 1] - t
|
| 124 |
+
|
| 125 |
+
return sol[-1].float()
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
class CausalMaskedDiffWithXvec(torch.nn.Module):
|
| 129 |
+
def __init__(
|
| 130 |
+
self,
|
| 131 |
+
input_size: int = 512,
|
| 132 |
+
output_size: int = 80,
|
| 133 |
+
spk_embed_dim: int = 192,
|
| 134 |
+
output_type: str = "mel",
|
| 135 |
+
vocab_size: int = 6561,
|
| 136 |
+
input_frame_rate: int = 25,
|
| 137 |
+
token_mel_ratio: int = 2,
|
| 138 |
+
pre_lookahead_len: int = 3,
|
| 139 |
+
encoder: torch.nn.Module = None,
|
| 140 |
+
decoder: torch.nn.Module = None,
|
| 141 |
+
):
|
| 142 |
+
super().__init__()
|
| 143 |
+
self.input_size = input_size
|
| 144 |
+
self.output_size = output_size
|
| 145 |
+
self.vocab_size = vocab_size
|
| 146 |
+
self.output_type = output_type
|
| 147 |
+
self.input_frame_rate = input_frame_rate
|
| 148 |
+
self.input_embedding = nn.Embedding(vocab_size, input_size)
|
| 149 |
+
self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
|
| 150 |
+
self.encoder = UpsampleConformerEncoder() if encoder is None else encoder
|
| 151 |
+
self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
|
| 152 |
+
self.decoder = CausalConditionalCFM() if decoder is None else decoder
|
| 153 |
+
self.token_mel_ratio = token_mel_ratio
|
| 154 |
+
self.pre_lookahead_len = pre_lookahead_len
|
| 155 |
+
|
| 156 |
+
@torch.inference_mode()
|
| 157 |
+
def forward(self,
|
| 158 |
+
token,
|
| 159 |
+
token_len,
|
| 160 |
+
prompt_feat,
|
| 161 |
+
prompt_feat_len,
|
| 162 |
+
embedding,
|
| 163 |
+
streaming,
|
| 164 |
+
finalize):
|
| 165 |
+
# xvec projection
|
| 166 |
+
embedding = F.normalize(embedding, dim=1)
|
| 167 |
+
embedding = self.spk_embed_affine_layer(embedding)
|
| 168 |
+
|
| 169 |
+
# concat text and prompt_text
|
| 170 |
+
mask = (~make_pad_mask(token_len, max_len=token.shape[1])).unsqueeze(-1).to(embedding)
|
| 171 |
+
token = self.input_embedding(torch.clamp(token, min=0)) * mask
|
| 172 |
+
|
| 173 |
+
# text encode
|
| 174 |
+
if finalize is True:
|
| 175 |
+
h, h_lengths = self.encoder(token, token_len, streaming=streaming)
|
| 176 |
+
else:
|
| 177 |
+
token, context = token[:, :-self.pre_lookahead_len], token[:, -self.pre_lookahead_len:]
|
| 178 |
+
h, h_lengths = self.encoder(token, token_len, context=context, streaming=streaming)
|
| 179 |
+
h = self.encoder_proj(h)
|
| 180 |
+
|
| 181 |
+
# get conditions
|
| 182 |
+
conds = torch.zeros_like(h, device=token.device)
|
| 183 |
+
for i, j in enumerate(prompt_feat_len):
|
| 184 |
+
conds[i, :j] = prompt_feat[i, :j]
|
| 185 |
+
conds = conds.transpose(1, 2)
|
| 186 |
+
|
| 187 |
+
h_lengths = h_lengths.sum(dim=-1).squeeze(dim=1)
|
| 188 |
+
mask = (~make_pad_mask(h_lengths, max_len=h.shape[1])).to(h)
|
| 189 |
+
feat, _ = self.decoder(
|
| 190 |
+
mu=h.transpose(1, 2).contiguous(),
|
| 191 |
+
mask=mask.unsqueeze(1),
|
| 192 |
+
spks=embedding,
|
| 193 |
+
cond=conds,
|
| 194 |
+
n_timesteps=15,
|
| 195 |
+
streaming=streaming
|
| 196 |
+
) # [B, num_mels, T]
|
| 197 |
+
return feat.float(), h_lengths
|
soulxpodcast/models/modules/flow_components/__init__.py
ADDED
|
File without changes
|
soulxpodcast/models/modules/flow_components/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (208 Bytes). View file
|
|
|
soulxpodcast/models/modules/flow_components/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (196 Bytes). View file
|
|
|
soulxpodcast/models/modules/flow_components/__pycache__/estimator.cpython-311.pyc
ADDED
|
Binary file (49.4 kB). View file
|
|
|
soulxpodcast/models/modules/flow_components/__pycache__/estimator.cpython-312.pyc
ADDED
|
Binary file (43.6 kB). View file
|
|
|
soulxpodcast/models/modules/flow_components/__pycache__/upsample_encoder.cpython-311.pyc
ADDED
|
Binary file (52 kB). View file
|
|
|
soulxpodcast/models/modules/flow_components/__pycache__/upsample_encoder.cpython-312.pyc
ADDED
|
Binary file (49.6 kB). View file
|
|
|
soulxpodcast/models/modules/flow_components/estimator.py
ADDED
|
@@ -0,0 +1,974 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from typing import Any, Dict, Optional, Tuple
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from diffusers.models.attention import (GEGLU, GELU, AdaLayerNorm,
|
| 8 |
+
AdaLayerNormZero, ApproximateGELU)
|
| 9 |
+
from diffusers.models.attention_processor import Attention
|
| 10 |
+
from diffusers.models.lora import LoRACompatibleLinear
|
| 11 |
+
from diffusers.utils.torch_utils import maybe_allow_in_graph
|
| 12 |
+
from einops import pack, rearrange, repeat
|
| 13 |
+
|
| 14 |
+
from soulxpodcast.models.modules.flow_components.upsample_encoder import \
|
| 15 |
+
add_optional_chunk_mask
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
|
| 19 |
+
assert mask.dtype == torch.bool
|
| 20 |
+
assert dtype in [torch.float32, torch.bfloat16, torch.float16]
|
| 21 |
+
mask = mask.to(dtype)
|
| 22 |
+
# attention mask bias
|
| 23 |
+
# NOTE(Mddct): torch.finfo jit issues
|
| 24 |
+
# chunk_masks = (1.0 - chunk_masks) * torch.finfo(dtype).min
|
| 25 |
+
mask = (1.0 - mask) * -1.0e+10
|
| 26 |
+
return mask
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class SnakeBeta(nn.Module):
|
| 30 |
+
"""
|
| 31 |
+
A modified Snake function which uses separate parameters for the magnitude of the periodic components
|
| 32 |
+
Shape:
|
| 33 |
+
- Input: (B, C, T)
|
| 34 |
+
- Output: (B, C, T), same shape as the input
|
| 35 |
+
Parameters:
|
| 36 |
+
- alpha - trainable parameter that controls frequency
|
| 37 |
+
- beta - trainable parameter that controls magnitude
|
| 38 |
+
References:
|
| 39 |
+
- This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
|
| 40 |
+
https://arxiv.org/abs/2006.08195
|
| 41 |
+
Examples:
|
| 42 |
+
>>> a1 = snakebeta(256)
|
| 43 |
+
>>> x = torch.randn(256)
|
| 44 |
+
>>> x = a1(x)
|
| 45 |
+
|
| 46 |
+
Args:
|
| 47 |
+
in_features: shape of the input
|
| 48 |
+
out_features: shape of the output
|
| 49 |
+
alpha: trainable parameter that controls frequency
|
| 50 |
+
alpha_trainable: whether alpha is trainable
|
| 51 |
+
alpha_logscale: whether to use log scale for alpha
|
| 52 |
+
alpha is initialized to 1 by default, higher values = higher-frequency.
|
| 53 |
+
beta is initialized to 1 by default, higher values = higher-magnitude.
|
| 54 |
+
alpha will be trained along with the rest of your model.
|
| 55 |
+
"""
|
| 56 |
+
|
| 57 |
+
def __init__(self, in_features, out_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True):
|
| 58 |
+
super().__init__()
|
| 59 |
+
self.in_features = out_features if isinstance(out_features, list) else [out_features]
|
| 60 |
+
self.proj = LoRACompatibleLinear(in_features, out_features)
|
| 61 |
+
|
| 62 |
+
# initialize alpha
|
| 63 |
+
self.alpha_logscale = alpha_logscale
|
| 64 |
+
if self.alpha_logscale: # log scale alphas initialized to zeros
|
| 65 |
+
self.alpha = nn.Parameter(torch.zeros(self.in_features) * alpha)
|
| 66 |
+
self.beta = nn.Parameter(torch.zeros(self.in_features) * alpha)
|
| 67 |
+
else: # linear scale alphas initialized to ones
|
| 68 |
+
self.alpha = nn.Parameter(torch.ones(self.in_features) * alpha)
|
| 69 |
+
self.beta = nn.Parameter(torch.ones(self.in_features) * alpha)
|
| 70 |
+
|
| 71 |
+
self.alpha.requires_grad = alpha_trainable
|
| 72 |
+
self.beta.requires_grad = alpha_trainable
|
| 73 |
+
|
| 74 |
+
self.no_div_by_zero = 0.000000001
|
| 75 |
+
|
| 76 |
+
def forward(self, x):
|
| 77 |
+
"""
|
| 78 |
+
Forward pass of the function.
|
| 79 |
+
Applies the function to the input elementwise.
|
| 80 |
+
SnakeBeta ∶= x + 1/b * sin^2 (xa)
|
| 81 |
+
"""
|
| 82 |
+
x = self.proj(x)
|
| 83 |
+
if self.alpha_logscale:
|
| 84 |
+
alpha = torch.exp(self.alpha)
|
| 85 |
+
beta = torch.exp(self.beta)
|
| 86 |
+
else:
|
| 87 |
+
alpha = self.alpha
|
| 88 |
+
beta = self.beta
|
| 89 |
+
|
| 90 |
+
x = x + (1.0 / (beta + self.no_div_by_zero)) * torch.pow(torch.sin(x * alpha), 2)
|
| 91 |
+
|
| 92 |
+
return x
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
class FeedForward(nn.Module):
|
| 96 |
+
r"""
|
| 97 |
+
A feed-forward layer.
|
| 98 |
+
|
| 99 |
+
Parameters:
|
| 100 |
+
dim (`int`): The number of channels in the input.
|
| 101 |
+
dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
|
| 102 |
+
mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
|
| 103 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
| 104 |
+
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
| 105 |
+
final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
|
| 106 |
+
"""
|
| 107 |
+
|
| 108 |
+
def __init__(
|
| 109 |
+
self,
|
| 110 |
+
dim: int,
|
| 111 |
+
dim_out: Optional[int] = None,
|
| 112 |
+
mult: int = 4,
|
| 113 |
+
dropout: float = 0.0,
|
| 114 |
+
activation_fn: str = "geglu",
|
| 115 |
+
final_dropout: bool = False,
|
| 116 |
+
):
|
| 117 |
+
super().__init__()
|
| 118 |
+
inner_dim = int(dim * mult)
|
| 119 |
+
dim_out = dim_out if dim_out is not None else dim
|
| 120 |
+
|
| 121 |
+
if activation_fn == "gelu":
|
| 122 |
+
act_fn = GELU(dim, inner_dim)
|
| 123 |
+
if activation_fn == "gelu-approximate":
|
| 124 |
+
act_fn = GELU(dim, inner_dim, approximate="tanh")
|
| 125 |
+
elif activation_fn == "geglu":
|
| 126 |
+
act_fn = GEGLU(dim, inner_dim)
|
| 127 |
+
elif activation_fn == "geglu-approximate":
|
| 128 |
+
act_fn = ApproximateGELU(dim, inner_dim)
|
| 129 |
+
elif activation_fn == "snakebeta":
|
| 130 |
+
act_fn = SnakeBeta(dim, inner_dim)
|
| 131 |
+
|
| 132 |
+
self.net = nn.ModuleList([])
|
| 133 |
+
# project in
|
| 134 |
+
self.net.append(act_fn)
|
| 135 |
+
# project dropout
|
| 136 |
+
self.net.append(nn.Dropout(dropout))
|
| 137 |
+
# project out
|
| 138 |
+
self.net.append(LoRACompatibleLinear(inner_dim, dim_out))
|
| 139 |
+
# FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
|
| 140 |
+
if final_dropout:
|
| 141 |
+
self.net.append(nn.Dropout(dropout))
|
| 142 |
+
|
| 143 |
+
def forward(self, hidden_states):
|
| 144 |
+
for module in self.net:
|
| 145 |
+
hidden_states = module(hidden_states)
|
| 146 |
+
return hidden_states
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
@maybe_allow_in_graph
|
| 150 |
+
class BasicTransformerBlock(nn.Module):
|
| 151 |
+
r"""
|
| 152 |
+
A basic Transformer block.
|
| 153 |
+
|
| 154 |
+
Parameters:
|
| 155 |
+
dim (`int`): The number of channels in the input and output.
|
| 156 |
+
num_attention_heads (`int`): The number of heads to use for multi-head attention.
|
| 157 |
+
attention_head_dim (`int`): The number of channels in each head.
|
| 158 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
| 159 |
+
cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
|
| 160 |
+
only_cross_attention (`bool`, *optional*):
|
| 161 |
+
Whether to use only cross-attention layers. In this case two cross attention layers are used.
|
| 162 |
+
double_self_attention (`bool`, *optional*):
|
| 163 |
+
Whether to use two self-attention layers. In this case no cross attention layers are used.
|
| 164 |
+
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
| 165 |
+
num_embeds_ada_norm (:
|
| 166 |
+
obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
|
| 167 |
+
attention_bias (:
|
| 168 |
+
obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
|
| 169 |
+
"""
|
| 170 |
+
|
| 171 |
+
def __init__(
|
| 172 |
+
self,
|
| 173 |
+
dim: int,
|
| 174 |
+
num_attention_heads: int,
|
| 175 |
+
attention_head_dim: int,
|
| 176 |
+
dropout=0.0,
|
| 177 |
+
cross_attention_dim: Optional[int] = None,
|
| 178 |
+
activation_fn: str = "geglu",
|
| 179 |
+
num_embeds_ada_norm: Optional[int] = None,
|
| 180 |
+
attention_bias: bool = False,
|
| 181 |
+
only_cross_attention: bool = False,
|
| 182 |
+
double_self_attention: bool = False,
|
| 183 |
+
upcast_attention: bool = False,
|
| 184 |
+
norm_elementwise_affine: bool = True,
|
| 185 |
+
norm_type: str = "layer_norm",
|
| 186 |
+
final_dropout: bool = False,
|
| 187 |
+
):
|
| 188 |
+
super().__init__()
|
| 189 |
+
self.only_cross_attention = only_cross_attention
|
| 190 |
+
|
| 191 |
+
self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
|
| 192 |
+
self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
|
| 193 |
+
|
| 194 |
+
if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
|
| 195 |
+
raise ValueError(
|
| 196 |
+
f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
|
| 197 |
+
f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
|
| 198 |
+
)
|
| 199 |
+
|
| 200 |
+
# Define 3 blocks. Each block has its own normalization layer.
|
| 201 |
+
# 1. Self-Attn
|
| 202 |
+
if self.use_ada_layer_norm:
|
| 203 |
+
self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
|
| 204 |
+
elif self.use_ada_layer_norm_zero:
|
| 205 |
+
self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
|
| 206 |
+
else:
|
| 207 |
+
self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
|
| 208 |
+
self.attn1 = Attention(
|
| 209 |
+
query_dim=dim,
|
| 210 |
+
heads=num_attention_heads,
|
| 211 |
+
dim_head=attention_head_dim,
|
| 212 |
+
dropout=dropout,
|
| 213 |
+
bias=attention_bias,
|
| 214 |
+
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
|
| 215 |
+
upcast_attention=upcast_attention,
|
| 216 |
+
)
|
| 217 |
+
|
| 218 |
+
# 2. Cross-Attn
|
| 219 |
+
if cross_attention_dim is not None or double_self_attention:
|
| 220 |
+
# We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
|
| 221 |
+
# I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
|
| 222 |
+
# the second cross attention block.
|
| 223 |
+
self.norm2 = (
|
| 224 |
+
AdaLayerNorm(dim, num_embeds_ada_norm)
|
| 225 |
+
if self.use_ada_layer_norm
|
| 226 |
+
else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
|
| 227 |
+
)
|
| 228 |
+
self.attn2 = Attention(
|
| 229 |
+
query_dim=dim,
|
| 230 |
+
cross_attention_dim=cross_attention_dim if not double_self_attention else None,
|
| 231 |
+
heads=num_attention_heads,
|
| 232 |
+
dim_head=attention_head_dim,
|
| 233 |
+
dropout=dropout,
|
| 234 |
+
bias=attention_bias,
|
| 235 |
+
upcast_attention=upcast_attention,
|
| 236 |
+
# scale_qk=False, # uncomment this to not to use flash attention
|
| 237 |
+
) # is self-attn if encoder_hidden_states is none
|
| 238 |
+
else:
|
| 239 |
+
self.norm2 = None
|
| 240 |
+
self.attn2 = None
|
| 241 |
+
|
| 242 |
+
# 3. Feed-forward
|
| 243 |
+
self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
|
| 244 |
+
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
|
| 245 |
+
|
| 246 |
+
# let chunk size default to None
|
| 247 |
+
self._chunk_size = None
|
| 248 |
+
self._chunk_dim = 0
|
| 249 |
+
|
| 250 |
+
def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
|
| 251 |
+
# Sets chunk feed-forward
|
| 252 |
+
self._chunk_size = chunk_size
|
| 253 |
+
self._chunk_dim = dim
|
| 254 |
+
|
| 255 |
+
def forward(
|
| 256 |
+
self,
|
| 257 |
+
hidden_states: torch.FloatTensor,
|
| 258 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
| 259 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
| 260 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
| 261 |
+
timestep: Optional[torch.LongTensor] = None,
|
| 262 |
+
cross_attention_kwargs: Dict[str, Any] = None,
|
| 263 |
+
class_labels: Optional[torch.LongTensor] = None,
|
| 264 |
+
):
|
| 265 |
+
# Notice that normalization is always applied before the real computation in the following blocks.
|
| 266 |
+
# 1. Self-Attention
|
| 267 |
+
if self.use_ada_layer_norm:
|
| 268 |
+
norm_hidden_states = self.norm1(hidden_states, timestep)
|
| 269 |
+
elif self.use_ada_layer_norm_zero:
|
| 270 |
+
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
|
| 271 |
+
hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
|
| 272 |
+
)
|
| 273 |
+
else:
|
| 274 |
+
norm_hidden_states = self.norm1(hidden_states)
|
| 275 |
+
|
| 276 |
+
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
|
| 277 |
+
|
| 278 |
+
attn_output = self.attn1(
|
| 279 |
+
norm_hidden_states,
|
| 280 |
+
encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
|
| 281 |
+
attention_mask=encoder_attention_mask if self.only_cross_attention else attention_mask,
|
| 282 |
+
**cross_attention_kwargs,
|
| 283 |
+
)
|
| 284 |
+
if self.use_ada_layer_norm_zero:
|
| 285 |
+
attn_output = gate_msa.unsqueeze(1) * attn_output
|
| 286 |
+
hidden_states = attn_output + hidden_states
|
| 287 |
+
|
| 288 |
+
# 2. Cross-Attention
|
| 289 |
+
if self.attn2 is not None:
|
| 290 |
+
norm_hidden_states = (
|
| 291 |
+
self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
|
| 292 |
+
)
|
| 293 |
+
|
| 294 |
+
attn_output = self.attn2(
|
| 295 |
+
norm_hidden_states,
|
| 296 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 297 |
+
attention_mask=encoder_attention_mask,
|
| 298 |
+
**cross_attention_kwargs,
|
| 299 |
+
)
|
| 300 |
+
hidden_states = attn_output + hidden_states
|
| 301 |
+
|
| 302 |
+
# 3. Feed-forward
|
| 303 |
+
norm_hidden_states = self.norm3(hidden_states)
|
| 304 |
+
|
| 305 |
+
if self.use_ada_layer_norm_zero:
|
| 306 |
+
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
| 307 |
+
|
| 308 |
+
if self._chunk_size is not None:
|
| 309 |
+
# "feed_forward_chunk_size" can be used to save memory
|
| 310 |
+
if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
|
| 311 |
+
raise ValueError(
|
| 312 |
+
f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
|
| 313 |
+
)
|
| 314 |
+
|
| 315 |
+
num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
|
| 316 |
+
ff_output = torch.cat(
|
| 317 |
+
[self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
|
| 318 |
+
dim=self._chunk_dim,
|
| 319 |
+
)
|
| 320 |
+
else:
|
| 321 |
+
ff_output = self.ff(norm_hidden_states)
|
| 322 |
+
|
| 323 |
+
if self.use_ada_layer_norm_zero:
|
| 324 |
+
ff_output = gate_mlp.unsqueeze(1) * ff_output
|
| 325 |
+
|
| 326 |
+
hidden_states = ff_output + hidden_states
|
| 327 |
+
|
| 328 |
+
return hidden_states
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
class SinusoidalPosEmb(torch.nn.Module):
|
| 332 |
+
def __init__(self, dim):
|
| 333 |
+
super().__init__()
|
| 334 |
+
self.dim = dim
|
| 335 |
+
assert self.dim % 2 == 0, "SinusoidalPosEmb requires dim to be even"
|
| 336 |
+
|
| 337 |
+
def forward(self, x, scale=1000):
|
| 338 |
+
if x.ndim < 1:
|
| 339 |
+
x = x.unsqueeze(0)
|
| 340 |
+
device = x.device
|
| 341 |
+
half_dim = self.dim // 2
|
| 342 |
+
emb = math.log(10000) / (half_dim - 1)
|
| 343 |
+
emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
|
| 344 |
+
emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
|
| 345 |
+
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
|
| 346 |
+
return emb
|
| 347 |
+
|
| 348 |
+
|
| 349 |
+
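# Illustrative sketch (not from the uploaded file): SinusoidalPosEmb maps a batch
# of scalar timesteps to a (batch, dim) embedding, half sine and half cosine
# features. The batch size and dim below are hypothetical.
import torch

pos_emb = SinusoidalPosEmb(dim=320)
t = torch.rand(4)              # four timesteps in [0, 1)
emb = pos_emb(t)               # internally scaled by scale=1000
assert emb.shape == (4, 320)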
class Block1D(torch.nn.Module):
|
| 350 |
+
def __init__(self, dim, dim_out, groups=8):
|
| 351 |
+
super().__init__()
|
| 352 |
+
self.block = torch.nn.Sequential(
|
| 353 |
+
torch.nn.Conv1d(dim, dim_out, 3, padding=1),
|
| 354 |
+
torch.nn.GroupNorm(groups, dim_out),
|
| 355 |
+
nn.Mish(),
|
| 356 |
+
)
|
| 357 |
+
|
| 358 |
+
def forward(self, x, mask):
|
| 359 |
+
output = self.block(x * mask)
|
| 360 |
+
return output * mask
|
| 361 |
+
|
| 362 |
+
|
| 363 |
+
class ResnetBlock1D(torch.nn.Module):
|
| 364 |
+
def __init__(self, dim, dim_out, time_emb_dim, groups=8):
|
| 365 |
+
super().__init__()
|
| 366 |
+
self.mlp = torch.nn.Sequential(nn.Mish(), torch.nn.Linear(time_emb_dim, dim_out))
|
| 367 |
+
|
| 368 |
+
self.block1 = Block1D(dim, dim_out, groups=groups)
|
| 369 |
+
self.block2 = Block1D(dim_out, dim_out, groups=groups)
|
| 370 |
+
|
| 371 |
+
self.res_conv = torch.nn.Conv1d(dim, dim_out, 1)
|
| 372 |
+
|
| 373 |
+
def forward(self, x, mask, time_emb):
|
| 374 |
+
h = self.block1(x, mask)
|
| 375 |
+
h += self.mlp(time_emb).unsqueeze(-1)
|
| 376 |
+
h = self.block2(h, mask)
|
| 377 |
+
output = h + self.res_conv(x * mask)
|
| 378 |
+
return output
|
| 379 |
+
|
| 380 |
+
|
| 381 |
+
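# Illustrative sketch (not from the uploaded file): ResnetBlock1D keeps the
# (batch, dim_out, time) layout, applies the binary mask at every step, and adds
# the time embedding per channel through a Mish + Linear MLP. Sizes hypothetical.
import torch

block = ResnetBlock1D(dim=80, dim_out=256, time_emb_dim=1024)
x = torch.randn(2, 80, 50)
mask = torch.ones(2, 1, 50)
t_emb = torch.randn(2, 1024)
y = block(x, mask, t_emb)
assert y.shape == (2, 256, 50)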
class Downsample1D(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.conv = torch.nn.Conv1d(dim, dim, 3, 2, 1)

    def forward(self, x):
        return self.conv(x)
|
| 388 |
+
|
| 389 |
+
|
| 390 |
+
class TimestepEmbedding(nn.Module):
|
| 391 |
+
def __init__(
|
| 392 |
+
self,
|
| 393 |
+
in_channels: int,
|
| 394 |
+
time_embed_dim: int,
|
| 395 |
+
act_fn: str = "silu",
|
| 396 |
+
out_dim: int = None,
|
| 397 |
+
post_act_fn: Optional[str] = None,
|
| 398 |
+
cond_proj_dim=None,
|
| 399 |
+
):
|
| 400 |
+
super().__init__()
|
| 401 |
+
assert act_fn == "silu", "act_fn must be silu"
|
| 402 |
+
|
| 403 |
+
self.linear_1 = nn.Linear(in_channels, time_embed_dim)
|
| 404 |
+
|
| 405 |
+
if cond_proj_dim is not None:
|
| 406 |
+
self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
|
| 407 |
+
else:
|
| 408 |
+
self.cond_proj = None
|
| 409 |
+
|
| 410 |
+
self.act = nn.SiLU()
|
| 411 |
+
|
| 412 |
+
if out_dim is not None:
|
| 413 |
+
time_embed_dim_out = out_dim
|
| 414 |
+
else:
|
| 415 |
+
time_embed_dim_out = time_embed_dim
|
| 416 |
+
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)
|
| 417 |
+
|
| 418 |
+
if post_act_fn is None:
|
| 419 |
+
self.post_act = None
|
| 420 |
+
else:
|
| 421 |
+
self.post_act = nn.SiLU()
|
| 422 |
+
|
| 423 |
+
def forward(self, sample, condition=None):
|
| 424 |
+
if condition is not None:
|
| 425 |
+
sample = sample + self.cond_proj(condition)
|
| 426 |
+
sample = self.linear_1(sample)
|
| 427 |
+
|
| 428 |
+
if self.act is not None:
|
| 429 |
+
sample = self.act(sample)
|
| 430 |
+
|
| 431 |
+
sample = self.linear_2(sample)
|
| 432 |
+
|
| 433 |
+
if self.post_act is not None:
|
| 434 |
+
sample = self.post_act(sample)
|
| 435 |
+
return sample
|
| 436 |
+
|
| 437 |
+
|
| 438 |
+
class Upsample1D(nn.Module):
|
| 439 |
+
"""A 1D upsampling layer with an optional convolution.
|
| 440 |
+
|
| 441 |
+
Parameters:
|
| 442 |
+
channels (`int`):
|
| 443 |
+
number of channels in the inputs and outputs.
|
| 444 |
+
use_conv (`bool`, default `False`):
|
| 445 |
+
option to use a convolution.
|
| 446 |
+
use_conv_transpose (`bool`, default `False`):
|
| 447 |
+
option to use a convolution transpose.
|
| 448 |
+
out_channels (`int`, optional):
|
| 449 |
+
number of output channels. Defaults to `channels`.
|
| 450 |
+
"""
|
| 451 |
+
|
| 452 |
+
def __init__(self, channels, use_conv=False, use_conv_transpose=True, out_channels=None, name="conv"):
|
| 453 |
+
super().__init__()
|
| 454 |
+
self.channels = channels
|
| 455 |
+
self.out_channels = out_channels or channels
|
| 456 |
+
self.use_conv = use_conv
|
| 457 |
+
self.use_conv_transpose = use_conv_transpose
|
| 458 |
+
self.name = name
|
| 459 |
+
|
| 460 |
+
self.conv = None
|
| 461 |
+
if use_conv_transpose:
|
| 462 |
+
self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
|
| 463 |
+
elif use_conv:
|
| 464 |
+
self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)
|
| 465 |
+
|
| 466 |
+
def forward(self, inputs):
|
| 467 |
+
assert inputs.shape[1] == self.channels
|
| 468 |
+
if self.use_conv_transpose:
|
| 469 |
+
return self.conv(inputs)
|
| 470 |
+
|
| 471 |
+
outputs = F.interpolate(inputs, scale_factor=2.0, mode="nearest")
|
| 472 |
+
|
| 473 |
+
if self.use_conv:
|
| 474 |
+
outputs = self.conv(outputs)
|
| 475 |
+
|
| 476 |
+
return outputs
|
| 477 |
+
|
| 478 |
+
|
| 479 |
+
class Transpose(torch.nn.Module):
    def __init__(self, dim0: int, dim1: int):
        super().__init__()
        self.dim0 = dim0
        self.dim1 = dim1

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = torch.transpose(x, self.dim0, self.dim1)
        return x
|
| 488 |
+
|
| 489 |
+
|
| 490 |
+
class CausalConv1d(torch.nn.Conv1d):
|
| 491 |
+
def __init__(
|
| 492 |
+
self,
|
| 493 |
+
in_channels: int,
|
| 494 |
+
out_channels: int,
|
| 495 |
+
kernel_size: int,
|
| 496 |
+
stride: int = 1,
|
| 497 |
+
dilation: int = 1,
|
| 498 |
+
groups: int = 1,
|
| 499 |
+
bias: bool = True,
|
| 500 |
+
padding_mode: str = 'zeros',
|
| 501 |
+
device=None,
|
| 502 |
+
dtype=None
|
| 503 |
+
) -> None:
|
| 504 |
+
super(CausalConv1d, self).__init__(in_channels, out_channels,
|
| 505 |
+
kernel_size, stride,
|
| 506 |
+
padding=0, dilation=dilation,
|
| 507 |
+
groups=groups, bias=bias,
|
| 508 |
+
padding_mode=padding_mode,
|
| 509 |
+
device=device, dtype=dtype)
|
| 510 |
+
assert stride == 1
|
| 511 |
+
self.causal_padding = kernel_size - 1
|
| 512 |
+
|
| 513 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 514 |
+
x = F.pad(x, (self.causal_padding, 0), value=0.0)
|
| 515 |
+
x = super(CausalConv1d, self).forward(x)
|
| 516 |
+
return x
|
| 517 |
+
|
| 518 |
+
|
| 519 |
+
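# Illustrative sketch (not from the uploaded file): CausalConv1d left-pads by
# kernel_size - 1, so the output has the same length as the input and frame t
# never sees frames > t. Channel and length values below are hypothetical.
import torch

conv = CausalConv1d(in_channels=8, out_channels=8, kernel_size=3)
x = torch.randn(1, 8, 20)
y = conv(x)
assert y.shape == (1, 8, 20)   # length preserved, no lookahead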
class CausalBlock1D(Block1D):
|
| 520 |
+
def __init__(self, dim: int, dim_out: int):
|
| 521 |
+
super(CausalBlock1D, self).__init__(dim, dim_out)
|
| 522 |
+
self.block = torch.nn.Sequential(
|
| 523 |
+
CausalConv1d(dim, dim_out, 3),
|
| 524 |
+
Transpose(1, 2),
|
| 525 |
+
nn.LayerNorm(dim_out),
|
| 526 |
+
Transpose(1, 2),
|
| 527 |
+
nn.Mish(),
|
| 528 |
+
)
|
| 529 |
+
|
| 530 |
+
def forward(self, x: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 531 |
+
output = self.block(x * mask)
|
| 532 |
+
return output * mask
|
| 533 |
+
|
| 534 |
+
|
| 535 |
+
class CausalResnetBlock1D(ResnetBlock1D):
|
| 536 |
+
def __init__(self, dim: int, dim_out: int, time_emb_dim: int, groups: int = 8):
|
| 537 |
+
super(CausalResnetBlock1D, self).__init__(dim, dim_out, time_emb_dim, groups)
|
| 538 |
+
self.block1 = CausalBlock1D(dim, dim_out)
|
| 539 |
+
self.block2 = CausalBlock1D(dim_out, dim_out)
|
| 540 |
+
|
| 541 |
+
|
| 542 |
+
class ConditionalDecoder(nn.Module):
|
| 543 |
+
"""
|
| 544 |
+
This decoder requires an input with the same shape as the target. So, if your text content
|
| 545 |
+
is shorter or longer than the output, please re-sample it before feeding it to the decoder.
|
| 546 |
+
|
| 547 |
+
Args:
|
| 548 |
+
in_channels: number of input channels
|
| 549 |
+
out_channels: number of output channels
|
| 550 |
+
channels: tuple of channel dimensions
|
| 551 |
+
dropout: dropout rate
|
| 552 |
+
attention_head_dim: dimension of attention heads
|
| 553 |
+
n_blocks: number of transformer blocks
|
| 554 |
+
num_mid_blocks: number of middle blocks
|
| 555 |
+
num_heads: number of attention heads
|
| 556 |
+
act_fn: activation function name
|
| 557 |
+
"""
|
| 558 |
+
|
| 559 |
+
def __init__(
|
| 560 |
+
self,
|
| 561 |
+
in_channels,
|
| 562 |
+
out_channels,
|
| 563 |
+
channels=(256, 256),
|
| 564 |
+
dropout=0.05,
|
| 565 |
+
attention_head_dim=64,
|
| 566 |
+
n_blocks=1,
|
| 567 |
+
num_mid_blocks=2,
|
| 568 |
+
num_heads=4,
|
| 569 |
+
act_fn="snake",
|
| 570 |
+
):
|
| 571 |
+
super().__init__()
|
| 572 |
+
channels = tuple(channels)
|
| 573 |
+
self.in_channels = in_channels
|
| 574 |
+
self.out_channels = out_channels
|
| 575 |
+
|
| 576 |
+
self.time_embeddings = SinusoidalPosEmb(in_channels)
|
| 577 |
+
time_embed_dim = channels[0] * 4
|
| 578 |
+
self.time_mlp = TimestepEmbedding(
|
| 579 |
+
in_channels=in_channels,
|
| 580 |
+
time_embed_dim=time_embed_dim,
|
| 581 |
+
act_fn="silu",
|
| 582 |
+
)
|
| 583 |
+
self.down_blocks = nn.ModuleList([])
|
| 584 |
+
self.mid_blocks = nn.ModuleList([])
|
| 585 |
+
self.up_blocks = nn.ModuleList([])
|
| 586 |
+
|
| 587 |
+
output_channel = in_channels
|
| 588 |
+
for i in range(len(channels)): # pylint: disable=consider-using-enumerate
|
| 589 |
+
input_channel = output_channel
|
| 590 |
+
output_channel = channels[i]
|
| 591 |
+
is_last = i == len(channels) - 1
|
| 592 |
+
resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
|
| 593 |
+
transformer_blocks = nn.ModuleList(
|
| 594 |
+
[
|
| 595 |
+
BasicTransformerBlock(
|
| 596 |
+
dim=output_channel,
|
| 597 |
+
num_attention_heads=num_heads,
|
| 598 |
+
attention_head_dim=attention_head_dim,
|
| 599 |
+
dropout=dropout,
|
| 600 |
+
activation_fn=act_fn,
|
| 601 |
+
)
|
| 602 |
+
for _ in range(n_blocks)
|
| 603 |
+
]
|
| 604 |
+
)
|
| 605 |
+
downsample = (
|
| 606 |
+
Downsample1D(output_channel) if not is_last else nn.Conv1d(output_channel, output_channel, 3, padding=1)
|
| 607 |
+
)
|
| 608 |
+
self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
|
| 609 |
+
|
| 610 |
+
for _ in range(num_mid_blocks):
|
| 611 |
+
input_channel = channels[-1]
|
| 612 |
+
out_channels = channels[-1]
|
| 613 |
+
resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
|
| 614 |
+
|
| 615 |
+
transformer_blocks = nn.ModuleList(
|
| 616 |
+
[
|
| 617 |
+
BasicTransformerBlock(
|
| 618 |
+
dim=output_channel,
|
| 619 |
+
num_attention_heads=num_heads,
|
| 620 |
+
attention_head_dim=attention_head_dim,
|
| 621 |
+
dropout=dropout,
|
| 622 |
+
activation_fn=act_fn,
|
| 623 |
+
)
|
| 624 |
+
for _ in range(n_blocks)
|
| 625 |
+
]
|
| 626 |
+
)
|
| 627 |
+
|
| 628 |
+
self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))
|
| 629 |
+
|
| 630 |
+
channels = channels[::-1] + (channels[0],)
|
| 631 |
+
for i in range(len(channels) - 1):
|
| 632 |
+
input_channel = channels[i] * 2
|
| 633 |
+
output_channel = channels[i + 1]
|
| 634 |
+
is_last = i == len(channels) - 2
|
| 635 |
+
resnet = ResnetBlock1D(
|
| 636 |
+
dim=input_channel,
|
| 637 |
+
dim_out=output_channel,
|
| 638 |
+
time_emb_dim=time_embed_dim,
|
| 639 |
+
)
|
| 640 |
+
transformer_blocks = nn.ModuleList(
|
| 641 |
+
[
|
| 642 |
+
BasicTransformerBlock(
|
| 643 |
+
dim=output_channel,
|
| 644 |
+
num_attention_heads=num_heads,
|
| 645 |
+
attention_head_dim=attention_head_dim,
|
| 646 |
+
dropout=dropout,
|
| 647 |
+
activation_fn=act_fn,
|
| 648 |
+
)
|
| 649 |
+
for _ in range(n_blocks)
|
| 650 |
+
]
|
| 651 |
+
)
|
| 652 |
+
upsample = (
|
| 653 |
+
Upsample1D(output_channel, use_conv_transpose=True)
|
| 654 |
+
if not is_last
|
| 655 |
+
else nn.Conv1d(output_channel, output_channel, 3, padding=1)
|
| 656 |
+
)
|
| 657 |
+
self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
|
| 658 |
+
self.final_block = Block1D(channels[-1], channels[-1])
|
| 659 |
+
self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
|
| 660 |
+
self.initialize_weights()
|
| 661 |
+
|
| 662 |
+
def initialize_weights(self):
|
| 663 |
+
for m in self.modules():
|
| 664 |
+
if isinstance(m, nn.Conv1d):
|
| 665 |
+
nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
|
| 666 |
+
if m.bias is not None:
|
| 667 |
+
nn.init.constant_(m.bias, 0)
|
| 668 |
+
elif isinstance(m, nn.GroupNorm):
|
| 669 |
+
nn.init.constant_(m.weight, 1)
|
| 670 |
+
nn.init.constant_(m.bias, 0)
|
| 671 |
+
elif isinstance(m, nn.Linear):
|
| 672 |
+
nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
|
| 673 |
+
if m.bias is not None:
|
| 674 |
+
nn.init.constant_(m.bias, 0)
|
| 675 |
+
|
| 676 |
+
def forward(self, x, mask, mu, t, spks=None, cond=None, streaming=False):
|
| 677 |
+
"""Forward pass of the UNet1DConditional model.
|
| 678 |
+
|
| 679 |
+
Args:
|
| 680 |
+
x (torch.Tensor): shape (batch_size, in_channels, time)
|
| 681 |
+
mask (_type_): shape (batch_size, 1, time)
|
| 682 |
+
t (_type_): shape (batch_size)
|
| 683 |
+
spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None.
|
| 684 |
+
cond (_type_, optional): placeholder for future use. Defaults to None.
|
| 685 |
+
|
| 686 |
+
Raises:
|
| 687 |
+
ValueError: _description_
|
| 688 |
+
ValueError: _description_
|
| 689 |
+
|
| 690 |
+
Returns:
|
| 691 |
+
torch.Tensor: output tensor of shape (batch_size, out_channels, time)
|
| 692 |
+
"""
|
| 693 |
+
|
| 694 |
+
t = self.time_embeddings(t).to(t.dtype)
|
| 695 |
+
t = self.time_mlp(t)
|
| 696 |
+
|
| 697 |
+
x = pack([x, mu], "b * t")[0]
|
| 698 |
+
|
| 699 |
+
if spks is not None:
|
| 700 |
+
spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
|
| 701 |
+
x = pack([x, spks], "b * t")[0]
|
| 702 |
+
if cond is not None:
|
| 703 |
+
x = pack([x, cond], "b * t")[0]
|
| 704 |
+
|
| 705 |
+
hiddens = []
|
| 706 |
+
masks = [mask]
|
| 707 |
+
for resnet, transformer_blocks, downsample in self.down_blocks:
|
| 708 |
+
mask_down = masks[-1]
|
| 709 |
+
x = resnet(x, mask_down, t)
|
| 710 |
+
x = rearrange(x, "b c t -> b t c").contiguous()
|
| 711 |
+
attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
|
| 712 |
+
attn_mask = mask_to_bias(attn_mask, x.dtype)
|
| 713 |
+
for transformer_block in transformer_blocks:
|
| 714 |
+
x = transformer_block(
|
| 715 |
+
hidden_states=x,
|
| 716 |
+
attention_mask=attn_mask,
|
| 717 |
+
timestep=t,
|
| 718 |
+
)
|
| 719 |
+
x = rearrange(x, "b t c -> b c t").contiguous()
|
| 720 |
+
hiddens.append(x) # Save hidden states for skip connections
|
| 721 |
+
x = downsample(x * mask_down)
|
| 722 |
+
masks.append(mask_down[:, :, ::2])
|
| 723 |
+
masks = masks[:-1]
|
| 724 |
+
mask_mid = masks[-1]
|
| 725 |
+
|
| 726 |
+
for resnet, transformer_blocks in self.mid_blocks:
|
| 727 |
+
x = resnet(x, mask_mid, t)
|
| 728 |
+
x = rearrange(x, "b c t -> b t c").contiguous()
|
| 729 |
+
attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
|
| 730 |
+
attn_mask = mask_to_bias(attn_mask, x.dtype)
|
| 731 |
+
for transformer_block in transformer_blocks:
|
| 732 |
+
x = transformer_block(
|
| 733 |
+
hidden_states=x,
|
| 734 |
+
attention_mask=attn_mask,
|
| 735 |
+
timestep=t,
|
| 736 |
+
)
|
| 737 |
+
x = rearrange(x, "b t c -> b c t").contiguous()
|
| 738 |
+
|
| 739 |
+
for resnet, transformer_blocks, upsample in self.up_blocks:
|
| 740 |
+
mask_up = masks.pop()
|
| 741 |
+
skip = hiddens.pop()
|
| 742 |
+
x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
|
| 743 |
+
x = resnet(x, mask_up, t)
|
| 744 |
+
x = rearrange(x, "b c t -> b t c").contiguous()
|
| 745 |
+
attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
|
| 746 |
+
attn_mask = mask_to_bias(attn_mask, x.dtype)
|
| 747 |
+
for transformer_block in transformer_blocks:
|
| 748 |
+
x = transformer_block(
|
| 749 |
+
hidden_states=x,
|
| 750 |
+
attention_mask=attn_mask,
|
| 751 |
+
timestep=t,
|
| 752 |
+
)
|
| 753 |
+
x = rearrange(x, "b t c -> b c t").contiguous()
|
| 754 |
+
x = upsample(x * mask_up)
|
| 755 |
+
x = self.final_block(x, mask_up)
|
| 756 |
+
output = self.final_proj(x * mask_up)
|
| 757 |
+
return output * mask
|
| 758 |
+
|
| 759 |
+
|
| 760 |
+
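# Illustrative sketch (not from the uploaded file): one call to ConditionalDecoder
# for a single flow-matching step. All sizes are hypothetical; in_channels must
# equal the packed width of [x, mu] (plus spks/cond when they are provided) and
# out_channels is the mel dimension.
import torch

decoder = ConditionalDecoder(in_channels=160, out_channels=80, channels=(256, 256))
x = torch.randn(2, 80, 120)    # noisy mel
mu = torch.randn(2, 80, 120)   # conditioning features
mask = torch.ones(2, 1, 120)
t = torch.rand(2)              # flow-matching timesteps
out = decoder(x, mask, mu, t)  # -> (2, 80, 120)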
class CausalConditionalDecoder(ConditionalDecoder):
|
| 761 |
+
"""
|
| 762 |
+
This decoder requires an input with the same shape as the target. So, if your text content
|
| 763 |
+
is shorter or longer than the output, please re-sample it before feeding it to the decoder.
|
| 764 |
+
|
| 765 |
+
Args:
|
| 766 |
+
in_channels: number of input channels
|
| 767 |
+
out_channels: number of output channels
|
| 768 |
+
channels: list of channel dimensions
|
| 769 |
+
dropout: dropout rate
|
| 770 |
+
attention_head_dim: dimension of attention heads
|
| 771 |
+
n_blocks: number of transformer blocks
|
| 772 |
+
num_mid_blocks: number of middle blocks
|
| 773 |
+
num_heads: number of attention heads
|
| 774 |
+
act_fn: activation function name
|
| 775 |
+
static_chunk_size: size of static chunks
|
| 776 |
+
num_decoding_left_chunks: number of left chunks for decoding
|
| 777 |
+
"""
|
| 778 |
+
|
| 779 |
+
def __init__(
|
| 780 |
+
self,
|
| 781 |
+
in_channels=320,
|
| 782 |
+
out_channels=80,
|
| 783 |
+
channels=[256], # noqa
|
| 784 |
+
dropout=0.0,
|
| 785 |
+
attention_head_dim=64,
|
| 786 |
+
n_blocks=4,
|
| 787 |
+
num_mid_blocks=12,
|
| 788 |
+
num_heads=8,
|
| 789 |
+
act_fn="gelu",
|
| 790 |
+
static_chunk_size=50,
|
| 791 |
+
num_decoding_left_chunks=-1,
|
| 792 |
+
):
|
| 793 |
+
torch.nn.Module.__init__(self)
|
| 794 |
+
channels = tuple(channels)
|
| 795 |
+
self.in_channels = in_channels
|
| 796 |
+
self.out_channels = out_channels
|
| 797 |
+
self.time_embeddings = SinusoidalPosEmb(in_channels)
|
| 798 |
+
time_embed_dim = channels[0] * 4
|
| 799 |
+
self.time_mlp = TimestepEmbedding(
|
| 800 |
+
in_channels=in_channels,
|
| 801 |
+
time_embed_dim=time_embed_dim,
|
| 802 |
+
act_fn="silu",
|
| 803 |
+
)
|
| 804 |
+
self.static_chunk_size = static_chunk_size
|
| 805 |
+
self.num_decoding_left_chunks = num_decoding_left_chunks
|
| 806 |
+
self.down_blocks = nn.ModuleList([])
|
| 807 |
+
self.mid_blocks = nn.ModuleList([])
|
| 808 |
+
self.up_blocks = nn.ModuleList([])
|
| 809 |
+
|
| 810 |
+
output_channel = in_channels
|
| 811 |
+
for i in range(len(channels)): # pylint: disable=consider-using-enumerate
|
| 812 |
+
input_channel = output_channel
|
| 813 |
+
output_channel = channels[i]
|
| 814 |
+
is_last = i == len(channels) - 1
|
| 815 |
+
resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
|
| 816 |
+
transformer_blocks = nn.ModuleList(
|
| 817 |
+
[
|
| 818 |
+
BasicTransformerBlock(
|
| 819 |
+
dim=output_channel,
|
| 820 |
+
num_attention_heads=num_heads,
|
| 821 |
+
attention_head_dim=attention_head_dim,
|
| 822 |
+
dropout=dropout,
|
| 823 |
+
activation_fn=act_fn,
|
| 824 |
+
)
|
| 825 |
+
for _ in range(n_blocks)
|
| 826 |
+
]
|
| 827 |
+
)
|
| 828 |
+
downsample = (
|
| 829 |
+
Downsample1D(output_channel) if not is_last else CausalConv1d(output_channel, output_channel, 3)
|
| 830 |
+
)
|
| 831 |
+
self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
|
| 832 |
+
|
| 833 |
+
for _ in range(num_mid_blocks):
|
| 834 |
+
input_channel = channels[-1]
|
| 835 |
+
out_channels = channels[-1]
|
| 836 |
+
resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
|
| 837 |
+
|
| 838 |
+
transformer_blocks = nn.ModuleList(
|
| 839 |
+
[
|
| 840 |
+
BasicTransformerBlock(
|
| 841 |
+
dim=output_channel,
|
| 842 |
+
num_attention_heads=num_heads,
|
| 843 |
+
attention_head_dim=attention_head_dim,
|
| 844 |
+
dropout=dropout,
|
| 845 |
+
activation_fn=act_fn,
|
| 846 |
+
)
|
| 847 |
+
for _ in range(n_blocks)
|
| 848 |
+
]
|
| 849 |
+
)
|
| 850 |
+
|
| 851 |
+
self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))
|
| 852 |
+
|
| 853 |
+
channels = channels[::-1] + (channels[0],)
|
| 854 |
+
for i in range(len(channels) - 1):
|
| 855 |
+
input_channel = channels[i] * 2
|
| 856 |
+
output_channel = channels[i + 1]
|
| 857 |
+
is_last = i == len(channels) - 2
|
| 858 |
+
resnet = CausalResnetBlock1D(
|
| 859 |
+
dim=input_channel,
|
| 860 |
+
dim_out=output_channel,
|
| 861 |
+
time_emb_dim=time_embed_dim,
|
| 862 |
+
)
|
| 863 |
+
transformer_blocks = nn.ModuleList(
|
| 864 |
+
[
|
| 865 |
+
BasicTransformerBlock(
|
| 866 |
+
dim=output_channel,
|
| 867 |
+
num_attention_heads=num_heads,
|
| 868 |
+
attention_head_dim=attention_head_dim,
|
| 869 |
+
dropout=dropout,
|
| 870 |
+
activation_fn=act_fn,
|
| 871 |
+
)
|
| 872 |
+
for _ in range(n_blocks)
|
| 873 |
+
]
|
| 874 |
+
)
|
| 875 |
+
upsample = (
|
| 876 |
+
Upsample1D(output_channel, use_conv_transpose=True)
|
| 877 |
+
if not is_last
|
| 878 |
+
else CausalConv1d(output_channel, output_channel, 3)
|
| 879 |
+
)
|
| 880 |
+
self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
|
| 881 |
+
self.final_block = CausalBlock1D(channels[-1], channels[-1])
|
| 882 |
+
self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
|
| 883 |
+
self.initialize_weights()
|
| 884 |
+
|
| 885 |
+
def forward(self, x, mask, mu, t, spks=None, cond=None, streaming=False):
|
| 886 |
+
"""Forward pass of the UNet1DConditional model.
|
| 887 |
+
|
| 888 |
+
Args:
|
| 889 |
+
x (torch.Tensor): shape (batch_size, in_channels, time)
|
| 890 |
+
mask (_type_): shape (batch_size, 1, time)
|
| 891 |
+
t (_type_): shape (batch_size)
|
| 892 |
+
spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None.
|
| 893 |
+
cond (_type_, optional): placeholder for future use. Defaults to None.
|
| 894 |
+
|
| 895 |
+
Raises:
|
| 896 |
+
ValueError: _description_
|
| 897 |
+
ValueError: _description_
|
| 898 |
+
|
| 899 |
+
Returns:
|
| 900 |
+
torch.Tensor: output tensor of shape (batch_size, out_channels, time)
|
| 901 |
+
"""
|
| 902 |
+
t = self.time_embeddings(t).to(t.dtype)
|
| 903 |
+
t = self.time_mlp(t)
|
| 904 |
+
|
| 905 |
+
x = pack([x, mu], "b * t")[0]
|
| 906 |
+
|
| 907 |
+
if spks is not None:
|
| 908 |
+
spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
|
| 909 |
+
x = pack([x, spks], "b * t")[0]
|
| 910 |
+
if cond is not None:
|
| 911 |
+
x = pack([x, cond], "b * t")[0]
|
| 912 |
+
|
| 913 |
+
hiddens = []
|
| 914 |
+
masks = [mask]
|
| 915 |
+
for resnet, transformer_blocks, downsample in self.down_blocks:
|
| 916 |
+
mask_down = masks[-1]
|
| 917 |
+
x = resnet(x, mask_down, t)
|
| 918 |
+
x = rearrange(x, "b c t -> b t c").contiguous()
|
| 919 |
+
if streaming is True:
|
| 920 |
+
attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, -1)
|
| 921 |
+
else:
|
| 922 |
+
attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
|
| 923 |
+
attn_mask = mask_to_bias(attn_mask, x.dtype)
|
| 924 |
+
for transformer_block in transformer_blocks:
|
| 925 |
+
x = transformer_block(
|
| 926 |
+
hidden_states=x,
|
| 927 |
+
attention_mask=attn_mask,
|
| 928 |
+
timestep=t,
|
| 929 |
+
)
|
| 930 |
+
x = rearrange(x, "b t c -> b c t").contiguous()
|
| 931 |
+
hiddens.append(x) # Save hidden states for skip connections
|
| 932 |
+
x = downsample(x * mask_down)
|
| 933 |
+
masks.append(mask_down[:, :, ::2])
|
| 934 |
+
masks = masks[:-1]
|
| 935 |
+
mask_mid = masks[-1]
|
| 936 |
+
|
| 937 |
+
for resnet, transformer_blocks in self.mid_blocks:
|
| 938 |
+
x = resnet(x, mask_mid, t)
|
| 939 |
+
x = rearrange(x, "b c t -> b t c").contiguous()
|
| 940 |
+
if streaming is True:
|
| 941 |
+
attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, -1)
|
| 942 |
+
else:
|
| 943 |
+
attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
|
| 944 |
+
attn_mask = mask_to_bias(attn_mask, x.dtype)
|
| 945 |
+
for transformer_block in transformer_blocks:
|
| 946 |
+
x = transformer_block(
|
| 947 |
+
hidden_states=x,
|
| 948 |
+
attention_mask=attn_mask,
|
| 949 |
+
timestep=t,
|
| 950 |
+
)
|
| 951 |
+
x = rearrange(x, "b t c -> b c t").contiguous()
|
| 952 |
+
|
| 953 |
+
for resnet, transformer_blocks, upsample in self.up_blocks:
|
| 954 |
+
mask_up = masks.pop()
|
| 955 |
+
skip = hiddens.pop()
|
| 956 |
+
x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
|
| 957 |
+
x = resnet(x, mask_up, t)
|
| 958 |
+
x = rearrange(x, "b c t -> b t c").contiguous()
|
| 959 |
+
if streaming is True:
|
| 960 |
+
attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, -1)
|
| 961 |
+
else:
|
| 962 |
+
attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
|
| 963 |
+
attn_mask = mask_to_bias(attn_mask, x.dtype)
|
| 964 |
+
for transformer_block in transformer_blocks:
|
| 965 |
+
x = transformer_block(
|
| 966 |
+
hidden_states=x,
|
| 967 |
+
attention_mask=attn_mask,
|
| 968 |
+
timestep=t,
|
| 969 |
+
)
|
| 970 |
+
x = rearrange(x, "b t c -> b c t").contiguous()
|
| 971 |
+
x = upsample(x * mask_up)
|
| 972 |
+
x = self.final_block(x, mask_up)
|
| 973 |
+
output = self.final_proj(x * mask_up)
|
| 974 |
+
return output * mask
|
soulxpodcast/models/modules/flow_components/upsample_encoder.py
ADDED
@@ -0,0 +1,998 @@
| 1 |
+
import math
|
| 2 |
+
from typing import Optional, Tuple, Union
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def subsequent_chunk_mask(
|
| 10 |
+
size: int,
|
| 11 |
+
chunk_size: int,
|
| 12 |
+
num_left_chunks: int = -1,
|
| 13 |
+
device: torch.device = torch.device("cpu"),
|
| 14 |
+
) -> torch.Tensor:
|
| 15 |
+
"""Create mask for subsequent steps (size, size) with chunk size,
|
| 16 |
+
this is for streaming encoder
|
| 17 |
+
|
| 18 |
+
Args:
|
| 19 |
+
size (int): size of mask
|
| 20 |
+
chunk_size (int): size of chunk
|
| 21 |
+
num_left_chunks (int): number of left chunks
|
| 22 |
+
<0: use full chunk
|
| 23 |
+
>=0: use num_left_chunks
|
| 24 |
+
device (torch.device): "cpu" or "cuda" or torch.Tensor.device
|
| 25 |
+
|
| 26 |
+
Returns:
|
| 27 |
+
torch.Tensor: mask
|
| 28 |
+
|
| 29 |
+
Examples:
|
| 30 |
+
>>> subsequent_chunk_mask(4, 2)
|
| 31 |
+
[[1, 1, 0, 0],
|
| 32 |
+
[1, 1, 0, 0],
|
| 33 |
+
[1, 1, 1, 1],
|
| 34 |
+
[1, 1, 1, 1]]
|
| 35 |
+
"""
|
| 36 |
+
# NOTE this modified implementation meets onnx export requirements, but it doesn't support num_left_chunks
|
| 37 |
+
pos_idx = torch.arange(size, device=device)
|
| 38 |
+
block_value = (torch.div(pos_idx, chunk_size, rounding_mode='trunc') + 1) * chunk_size
|
| 39 |
+
ret = pos_idx.unsqueeze(0) < block_value.unsqueeze(1)
|
| 40 |
+
return ret
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
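# Illustrative sketch (not from the uploaded file): reproduces the docstring
# example above; with chunk_size=2 every frame can attend to its own chunk and
# everything to its left.
import torch

m = subsequent_chunk_mask(4, 2)
# tensor([[ True,  True, False, False],
#         [ True,  True, False, False],
#         [ True,  True,  True,  True],
#         [ True,  True,  True,  True]])
print(m)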
def add_optional_chunk_mask(xs: torch.Tensor,
|
| 44 |
+
masks: torch.Tensor,
|
| 45 |
+
use_dynamic_chunk: bool,
|
| 46 |
+
use_dynamic_left_chunk: bool,
|
| 47 |
+
decoding_chunk_size: int,
|
| 48 |
+
static_chunk_size: int,
|
| 49 |
+
num_decoding_left_chunks: int,
|
| 50 |
+
enable_full_context: bool = True):
|
| 51 |
+
""" Apply optional mask for encoder.
|
| 52 |
+
|
| 53 |
+
Args:
|
| 54 |
+
xs (torch.Tensor): padded input, (B, L, D), L for max length
|
| 55 |
+
mask (torch.Tensor): mask for xs, (B, 1, L)
|
| 56 |
+
use_dynamic_chunk (bool): whether to use dynamic chunk or not
|
| 57 |
+
use_dynamic_left_chunk (bool): whether to use dynamic left chunk for
|
| 58 |
+
training.
|
| 59 |
+
decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's
|
| 60 |
+
0: default for training, use random dynamic chunk.
|
| 61 |
+
<0: for decoding, use full chunk.
|
| 62 |
+
>0: for decoding, use fixed chunk size as set.
|
| 63 |
+
static_chunk_size (int): chunk size for static chunk training/decoding
|
| 64 |
+
if it's greater than 0, if use_dynamic_chunk is true,
|
| 65 |
+
this parameter will be ignored
|
| 66 |
+
num_decoding_left_chunks: number of left chunks, this is for decoding,
|
| 67 |
+
the chunk size is decoding_chunk_size.
|
| 68 |
+
>=0: use num_decoding_left_chunks
|
| 69 |
+
<0: use all left chunks
|
| 70 |
+
enable_full_context (bool):
|
| 71 |
+
True: chunk size is either [1, 25] or full context(max_len)
|
| 72 |
+
False: chunk size ~ U[1, 25]
|
| 73 |
+
|
| 74 |
+
Returns:
|
| 75 |
+
torch.Tensor: chunk mask of the input xs.
|
| 76 |
+
"""
|
| 77 |
+
# Whether to use chunk mask or not
|
| 78 |
+
if use_dynamic_chunk:
|
| 79 |
+
max_len = xs.size(1)
|
| 80 |
+
if decoding_chunk_size < 0:
|
| 81 |
+
chunk_size = max_len
|
| 82 |
+
num_left_chunks = -1
|
| 83 |
+
elif decoding_chunk_size > 0:
|
| 84 |
+
chunk_size = decoding_chunk_size
|
| 85 |
+
num_left_chunks = num_decoding_left_chunks
|
| 86 |
+
else:
|
| 87 |
+
# chunk size is either [1, 25] or full context(max_len).
|
| 88 |
+
# Since we use 4 times subsampling and allow up to 1s(100 frames)
|
| 89 |
+
# delay, the maximum frame is 100 / 4 = 25.
|
| 90 |
+
chunk_size = torch.randint(1, max_len, (1, )).item()
|
| 91 |
+
num_left_chunks = -1
|
| 92 |
+
if chunk_size > max_len // 2 and enable_full_context:
|
| 93 |
+
chunk_size = max_len
|
| 94 |
+
else:
|
| 95 |
+
chunk_size = chunk_size % 25 + 1
|
| 96 |
+
if use_dynamic_left_chunk:
|
| 97 |
+
max_left_chunks = (max_len - 1) // chunk_size
|
| 98 |
+
num_left_chunks = torch.randint(0, max_left_chunks,
|
| 99 |
+
(1, )).item()
|
| 100 |
+
chunk_masks = subsequent_chunk_mask(xs.size(1), chunk_size,
|
| 101 |
+
num_left_chunks,
|
| 102 |
+
xs.device) # (L, L)
|
| 103 |
+
chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)
|
| 104 |
+
chunk_masks = masks & chunk_masks # (B, L, L)
|
| 105 |
+
elif static_chunk_size > 0:
|
| 106 |
+
num_left_chunks = num_decoding_left_chunks
|
| 107 |
+
chunk_masks = subsequent_chunk_mask(xs.size(1), static_chunk_size,
|
| 108 |
+
num_left_chunks,
|
| 109 |
+
xs.device) # (L, L)
|
| 110 |
+
chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)
|
| 111 |
+
chunk_masks = masks & chunk_masks # (B, L, L)
|
| 112 |
+
else:
|
| 113 |
+
chunk_masks = masks
|
| 114 |
+
assert chunk_masks.dtype == torch.bool
|
| 115 |
+
if (chunk_masks.sum(dim=-1) == 0).sum().item() != 0:
|
| 116 |
+
print('get chunk_masks all false at some timestep, force set to true, make sure they are masked in future computation!')
|
| 117 |
+
chunk_masks[chunk_masks.sum(dim=-1) == 0] = True
|
| 118 |
+
return chunk_masks
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
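# Illustrative sketch (not from the uploaded file): the static-chunk branch used
# by the streaming decoder path, with hypothetical shapes. With
# static_chunk_size=2 and num_decoding_left_chunks=-1 every frame attends to its
# own chunk plus all chunks to its left.
import torch

xs = torch.randn(1, 4, 8)
masks = torch.ones(1, 1, 4, dtype=torch.bool)
chunk_masks = add_optional_chunk_mask(
    xs, masks,
    use_dynamic_chunk=False, use_dynamic_left_chunk=False,
    decoding_chunk_size=0, static_chunk_size=2, num_decoding_left_chunks=-1)
print(chunk_masks.shape)  # torch.Size([1, 4, 4])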
def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
|
| 122 |
+
"""Make mask tensor containing indices of padded part.
|
| 123 |
+
|
| 124 |
+
See description of make_non_pad_mask.
|
| 125 |
+
|
| 126 |
+
Args:
|
| 127 |
+
lengths (torch.Tensor): Batch of lengths (B,).
|
| 128 |
+
Returns:
|
| 129 |
+
torch.Tensor: Mask tensor containing indices of padded part.
|
| 130 |
+
|
| 131 |
+
Examples:
|
| 132 |
+
>>> lengths = [5, 3, 2]
|
| 133 |
+
>>> make_pad_mask(lengths)
|
| 134 |
+
masks = [[0, 0, 0, 0 ,0],
|
| 135 |
+
[0, 0, 0, 1, 1],
|
| 136 |
+
[0, 0, 1, 1, 1]]
|
| 137 |
+
"""
|
| 138 |
+
batch_size = lengths.size(0)
|
| 139 |
+
max_len = max_len if max_len > 0 else lengths.max().item()
|
| 140 |
+
seq_range = torch.arange(0,
|
| 141 |
+
max_len,
|
| 142 |
+
dtype=torch.int64,
|
| 143 |
+
device=lengths.device)
|
| 144 |
+
seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
|
| 145 |
+
seq_length_expand = lengths.unsqueeze(-1)
|
| 146 |
+
mask = seq_range_expand >= seq_length_expand
|
| 147 |
+
return mask
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
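# Illustrative sketch (not from the uploaded file): reproduces the docstring
# example above; True (1) marks padded positions.
import torch

lengths = torch.tensor([5, 3, 2])
print(make_pad_mask(lengths).int())
# [[0, 0, 0, 0, 0],
#  [0, 0, 0, 1, 1],
#  [0, 0, 1, 1, 1]]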
class EspnetRelPositionalEncoding(torch.nn.Module):
|
| 151 |
+
"""Relative positional encoding module (new implementation).
|
| 152 |
+
|
| 153 |
+
Details can be found in https://github.com/espnet/espnet/pull/2816.
|
| 154 |
+
|
| 155 |
+
See : Appendix B in https://arxiv.org/abs/1901.02860
|
| 156 |
+
|
| 157 |
+
Args:
|
| 158 |
+
d_model (int): Embedding dimension.
|
| 159 |
+
max_len (int): Maximum input length.
|
| 160 |
+
|
| 161 |
+
"""
|
| 162 |
+
|
| 163 |
+
def __init__(self, d_model: int, max_len: int = 5000):
|
| 164 |
+
super(EspnetRelPositionalEncoding, self).__init__()
|
| 165 |
+
self.d_model = d_model
|
| 166 |
+
self.xscale = math.sqrt(self.d_model)
|
| 167 |
+
self.pe = None
|
| 168 |
+
self.extend_pe(torch.tensor(0.0).expand(1, max_len))
|
| 169 |
+
|
| 170 |
+
def extend_pe(self, x: torch.Tensor):
|
| 171 |
+
"""Reset the positional encodings."""
|
| 172 |
+
if self.pe is not None:
|
| 173 |
+
# self.pe contains both positive and negative parts
|
| 174 |
+
# the length of self.pe is 2 * input_len - 1
|
| 175 |
+
if self.pe.size(1) >= x.size(1) * 2 - 1:
|
| 176 |
+
if self.pe.dtype != x.dtype or self.pe.device != x.device:
|
| 177 |
+
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
|
| 178 |
+
return
|
| 179 |
+
# Suppose `i` means the position of the query vector and `j` means the
|
| 180 |
+
# position of the key vector. We use positive relative positions when keys
|
| 181 |
+
# are to the left (i>j) and negative relative positions otherwise (i<j).
|
| 182 |
+
pe_positive = torch.zeros(x.size(1), self.d_model)
|
| 183 |
+
pe_negative = torch.zeros(x.size(1), self.d_model)
|
| 184 |
+
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
|
| 185 |
+
div_term = torch.exp(
|
| 186 |
+
torch.arange(0, self.d_model, 2, dtype=torch.float32)
|
| 187 |
+
* -(math.log(10000.0) / self.d_model)
|
| 188 |
+
)
|
| 189 |
+
pe_positive[:, 0::2] = torch.sin(position * div_term)
|
| 190 |
+
pe_positive[:, 1::2] = torch.cos(position * div_term)
|
| 191 |
+
pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
|
| 192 |
+
pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
|
| 193 |
+
|
| 194 |
+
# Reverse the order of positive indices and concat both positive and
|
| 195 |
+
# negative indices. This is used to support the shifting trick
|
| 196 |
+
# as in https://arxiv.org/abs/1901.02860
|
| 197 |
+
pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
|
| 198 |
+
pe_negative = pe_negative[1:].unsqueeze(0)
|
| 199 |
+
pe = torch.cat([pe_positive, pe_negative], dim=1)
|
| 200 |
+
self.pe = pe.to(device=x.device, dtype=x.dtype)
|
| 201 |
+
|
| 202 |
+
def forward(self, x: torch.Tensor, offset: Union[int, torch.Tensor] = 0) \
|
| 203 |
+
-> Tuple[torch.Tensor, torch.Tensor]:
|
| 204 |
+
"""Add positional encoding.
|
| 205 |
+
|
| 206 |
+
Args:
|
| 207 |
+
x (torch.Tensor): Input tensor (batch, time, `*`).
|
| 208 |
+
|
| 209 |
+
Returns:
|
| 210 |
+
torch.Tensor: Encoded tensor (batch, time, `*`).
|
| 211 |
+
|
| 212 |
+
"""
|
| 213 |
+
self.extend_pe(x)
|
| 214 |
+
x = x * self.xscale
|
| 215 |
+
pos_emb = self.position_encoding(size=x.size(1), offset=offset)
|
| 216 |
+
return x, pos_emb
|
| 217 |
+
|
| 218 |
+
def position_encoding(self,
|
| 219 |
+
offset: Union[int, torch.Tensor],
|
| 220 |
+
size: int) -> torch.Tensor:
|
| 221 |
+
""" For getting encoding in a streaming fashion
|
| 222 |
+
|
| 223 |
+
Attention!!!!!
|
| 224 |
+
we apply dropout only once at the whole utterance level in a non-
|
| 225 |
+
streaming way, but will call this function several times with
|
| 226 |
+
increasing input size in a streaming scenario, so the dropout will
|
| 227 |
+
be applied several times.
|
| 228 |
+
|
| 229 |
+
Args:
|
| 230 |
+
offset (int or torch.tensor): start offset
|
| 231 |
+
size (int): required size of position encoding
|
| 232 |
+
|
| 233 |
+
Returns:
|
| 234 |
+
torch.Tensor: Corresponding encoding
|
| 235 |
+
"""
|
| 236 |
+
# How to subscript a Union type:
|
| 237 |
+
# https://github.com/pytorch/pytorch/issues/69434
|
| 238 |
+
if isinstance(offset, int):
|
| 239 |
+
pos_emb = self.pe[
|
| 240 |
+
:,
|
| 241 |
+
self.pe.size(1) // 2 - size - offset + 1: self.pe.size(1) // 2 + size + offset,
|
| 242 |
+
]
|
| 243 |
+
elif isinstance(offset, torch.Tensor):
|
| 244 |
+
pos_emb = self.pe[
|
| 245 |
+
:,
|
| 246 |
+
self.pe.size(1) // 2 - size - offset + 1: self.pe.size(1) // 2 + size + offset,
|
| 247 |
+
]
|
| 248 |
+
return pos_emb
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
class LinearNoSubsampling(torch.nn.Module):
|
| 252 |
+
"""Linear transform the input without subsampling
|
| 253 |
+
|
| 254 |
+
Args:
|
| 255 |
+
idim (int): Input dimension.
|
| 256 |
+
odim (int): Output dimension.
|
| 257 |
+
pos_enc_class (torch.nn.Module): Positional encoding class.
|
| 258 |
+
|
| 259 |
+
"""
|
| 260 |
+
|
| 261 |
+
def __init__(self, idim: int, odim: int,
|
| 262 |
+
pos_enc_class: torch.nn.Module):
|
| 263 |
+
super().__init__()
|
| 264 |
+
self.out = torch.nn.Sequential(
|
| 265 |
+
torch.nn.Linear(idim, odim),
|
| 266 |
+
torch.nn.LayerNorm(odim, eps=1e-5),
|
| 267 |
+
)
|
| 268 |
+
self.pos_enc = pos_enc_class
|
| 269 |
+
self.right_context = 0
|
| 270 |
+
self.subsampling_rate = 1
|
| 271 |
+
|
| 272 |
+
def forward(
|
| 273 |
+
self,
|
| 274 |
+
x: torch.Tensor,
|
| 275 |
+
x_mask: torch.Tensor,
|
| 276 |
+
offset: Union[int, torch.Tensor] = 0
|
| 277 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 278 |
+
"""Input x.
|
| 279 |
+
|
| 280 |
+
Args:
|
| 281 |
+
x (torch.Tensor): Input tensor (#batch, time, idim).
|
| 282 |
+
x_mask (torch.Tensor): Input mask (#batch, 1, time).
|
| 283 |
+
|
| 284 |
+
Returns:
|
| 285 |
+
torch.Tensor: linear input tensor (#batch, time', odim),
|
| 286 |
+
where time' = time .
|
| 287 |
+
torch.Tensor: linear input mask (#batch, 1, time'),
|
| 288 |
+
where time' = time .
|
| 289 |
+
|
| 290 |
+
"""
|
| 291 |
+
x = self.out(x)
|
| 292 |
+
x, pos_emb = self.pos_enc(x, offset)
|
| 293 |
+
return x, pos_emb, x_mask
|
| 294 |
+
|
| 295 |
+
def position_encoding(self, offset: Union[int, torch.Tensor],
|
| 296 |
+
size: int) -> torch.Tensor:
|
| 297 |
+
return self.pos_enc.position_encoding(offset, size)
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
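# Illustrative sketch (not from the uploaded file): LinearNoSubsampling projects
# input features to the model dimension and returns ESPnet-style relative
# position embeddings of length 2*T - 1. The feature sizes are hypothetical.
import torch

embed = LinearNoSubsampling(idim=512, odim=512,
                            pos_enc_class=EspnetRelPositionalEncoding(512))
xs = torch.randn(2, 10, 512)
xs_mask = torch.ones(2, 1, 10, dtype=torch.bool)
xs, pos_emb, xs_mask = embed(xs, xs_mask)
print(xs.shape, pos_emb.shape)  # torch.Size([2, 10, 512]) torch.Size([1, 19, 512])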
class Upsample1D(nn.Module):
|
| 301 |
+
"""A 1D upsampling layer with an optional convolution.
|
| 302 |
+
|
| 303 |
+
Parameters:
|
| 304 |
+
channels (`int`):
|
| 305 |
+
number of channels in the inputs and outputs.
|
| 306 |
+
use_conv (`bool`, default `False`):
|
| 307 |
+
option to use a convolution.
|
| 308 |
+
use_conv_transpose (`bool`, default `False`):
|
| 309 |
+
option to use a convolution transpose.
|
| 310 |
+
out_channels (`int`, optional):
|
| 311 |
+
number of output channels. Defaults to `channels`.
|
| 312 |
+
"""
|
| 313 |
+
|
| 314 |
+
def __init__(self, channels: int, out_channels: int, stride: int = 2):
|
| 315 |
+
super().__init__()
|
| 316 |
+
self.channels = channels
|
| 317 |
+
self.out_channels = out_channels
|
| 318 |
+
self.stride = stride
|
| 319 |
+
# In this mode, first repeat-interpolate, then conv with stride=1
|
| 320 |
+
self.conv = nn.Conv1d(self.channels, self.out_channels, stride * 2 + 1, stride=1, padding=0)
|
| 321 |
+
|
| 322 |
+
def forward(self, inputs: torch.Tensor, input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 323 |
+
outputs = F.interpolate(inputs, scale_factor=float(self.stride), mode="nearest")
|
| 324 |
+
outputs = F.pad(outputs, (self.stride * 2, 0), value=0.0)
|
| 325 |
+
outputs = self.conv(outputs)
|
| 326 |
+
return outputs, input_lengths * self.stride
|
| 327 |
+
|
| 328 |
+
|
| 329 |
+
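# Illustrative sketch (not from the uploaded file): this Upsample1D doubles the
# time axis with nearest-neighbour interpolation, left-pads by stride * 2 and
# applies a kernel of size stride * 2 + 1, so the output length is exactly
# stride * input_length. Sizes are hypothetical.
import torch

up = Upsample1D(channels=16, out_channels=16, stride=2)
x = torch.randn(1, 16, 50)
y, new_len = up(x, torch.tensor([50]))
assert y.shape == (1, 16, 100) and new_len.item() == 100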
class PreLookaheadLayer(nn.Module):
|
| 330 |
+
def __init__(self, channels: int, pre_lookahead_len: int = 1):
|
| 331 |
+
super().__init__()
|
| 332 |
+
self.channels = channels
|
| 333 |
+
self.pre_lookahead_len = pre_lookahead_len
|
| 334 |
+
self.conv1 = nn.Conv1d(
|
| 335 |
+
channels, channels,
|
| 336 |
+
kernel_size=pre_lookahead_len + 1,
|
| 337 |
+
stride=1, padding=0,
|
| 338 |
+
)
|
| 339 |
+
self.conv2 = nn.Conv1d(
|
| 340 |
+
channels, channels,
|
| 341 |
+
kernel_size=3, stride=1, padding=0,
|
| 342 |
+
)
|
| 343 |
+
|
| 344 |
+
def forward(self, inputs: torch.Tensor, context: torch.Tensor = torch.zeros(0, 0, 0)) -> torch.Tensor:
|
| 345 |
+
"""
|
| 346 |
+
inputs: (batch_size, seq_len, channels)
|
| 347 |
+
"""
|
| 348 |
+
outputs = inputs.transpose(1, 2).contiguous()
|
| 349 |
+
context = context.transpose(1, 2).contiguous()
|
| 350 |
+
# look ahead
|
| 351 |
+
if context.size(2) == 0:
|
| 352 |
+
outputs = F.pad(outputs, (0, self.pre_lookahead_len), mode='constant', value=0.0)
|
| 353 |
+
else:
|
| 354 |
+
assert self.training is False, 'you have passed context, make sure that you are running inference mode'
|
| 355 |
+
assert context.size(2) == self.pre_lookahead_len
|
| 356 |
+
outputs = F.pad(torch.concat([outputs, context], dim=2), (0, self.pre_lookahead_len - context.size(2)), mode='constant', value=0.0)
|
| 357 |
+
outputs = F.leaky_relu(self.conv1(outputs))
|
| 358 |
+
# outputs
|
| 359 |
+
outputs = F.pad(outputs, (self.conv2.kernel_size[0] - 1, 0), mode='constant', value=0.0)
|
| 360 |
+
outputs = self.conv2(outputs)
|
| 361 |
+
outputs = outputs.transpose(1, 2).contiguous()
|
| 362 |
+
|
| 363 |
+
# residual connection
|
| 364 |
+
outputs = outputs + inputs
|
| 365 |
+
return outputs
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
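# Illustrative sketch (not from the uploaded file): PreLookaheadLayer lets each
# frame look pre_lookahead_len frames ahead via right padding + convolution and
# returns a residual-connected tensor of the original (B, T, C) shape. Sizes are
# hypothetical; no streaming context is passed, so the default zero context is used.
import torch

layer = PreLookaheadLayer(channels=512, pre_lookahead_len=3).eval()
x = torch.randn(1, 40, 512)
with torch.no_grad():
    y = layer(x)
assert y.shape == x.shape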
class MultiHeadedAttention(nn.Module):
|
| 369 |
+
"""Multi-Head Attention layer.
|
| 370 |
+
|
| 371 |
+
Args:
|
| 372 |
+
n_head (int): The number of heads.
|
| 373 |
+
n_feat (int): The number of features.
|
| 374 |
+
dropout_rate (float): Dropout rate.
|
| 375 |
+
key_bias (bool): Whether to use bias in key linear layer.
|
| 376 |
+
|
| 377 |
+
"""
|
| 378 |
+
|
| 379 |
+
def __init__(self,
|
| 380 |
+
n_head: int,
|
| 381 |
+
n_feat: int,
|
| 382 |
+
dropout_rate: float,
|
| 383 |
+
key_bias: bool = True):
|
| 384 |
+
super().__init__()
|
| 385 |
+
assert n_feat % n_head == 0
|
| 386 |
+
# We assume d_v always equals d_k
|
| 387 |
+
self.d_k = n_feat // n_head
|
| 388 |
+
self.h = n_head
|
| 389 |
+
self.linear_q = nn.Linear(n_feat, n_feat)
|
| 390 |
+
self.linear_k = nn.Linear(n_feat, n_feat, bias=key_bias)
|
| 391 |
+
self.linear_v = nn.Linear(n_feat, n_feat)
|
| 392 |
+
self.linear_out = nn.Linear(n_feat, n_feat)
|
| 393 |
+
self.dropout = nn.Dropout(p=dropout_rate)
|
| 394 |
+
|
| 395 |
+
def forward_qkv(
|
| 396 |
+
self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
|
| 397 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 398 |
+
"""Transform query, key and value.
|
| 399 |
+
|
| 400 |
+
Args:
|
| 401 |
+
query (torch.Tensor): Query tensor (#batch, time1, size).
|
| 402 |
+
key (torch.Tensor): Key tensor (#batch, time2, size).
|
| 403 |
+
value (torch.Tensor): Value tensor (#batch, time2, size).
|
| 404 |
+
|
| 405 |
+
Returns:
|
| 406 |
+
torch.Tensor: Transformed query tensor, size
|
| 407 |
+
(#batch, n_head, time1, d_k).
|
| 408 |
+
torch.Tensor: Transformed key tensor, size
|
| 409 |
+
(#batch, n_head, time2, d_k).
|
| 410 |
+
torch.Tensor: Transformed value tensor, size
|
| 411 |
+
(#batch, n_head, time2, d_k).
|
| 412 |
+
|
| 413 |
+
"""
|
| 414 |
+
n_batch = query.size(0)
|
| 415 |
+
q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
|
| 416 |
+
k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
|
| 417 |
+
v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
|
| 418 |
+
q = q.transpose(1, 2) # (batch, head, time1, d_k)
|
| 419 |
+
k = k.transpose(1, 2) # (batch, head, time2, d_k)
|
| 420 |
+
v = v.transpose(1, 2) # (batch, head, time2, d_k)
|
| 421 |
+
|
| 422 |
+
return q, k, v
|
| 423 |
+
|
| 424 |
+
def forward_attention(
|
| 425 |
+
self,
|
| 426 |
+
value: torch.Tensor,
|
| 427 |
+
scores: torch.Tensor,
|
| 428 |
+
mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool)
|
| 429 |
+
) -> torch.Tensor:
|
| 430 |
+
"""Compute attention context vector.
|
| 431 |
+
|
| 432 |
+
Args:
|
| 433 |
+
value (torch.Tensor): Transformed value, size
|
| 434 |
+
(#batch, n_head, time2, d_k).
|
| 435 |
+
scores (torch.Tensor): Attention score, size
|
| 436 |
+
(#batch, n_head, time1, time2).
|
| 437 |
+
mask (torch.Tensor): Mask, size (#batch, 1, time2) or
|
| 438 |
+
(#batch, time1, time2), (0, 0, 0) means fake mask.
|
| 439 |
+
|
| 440 |
+
Returns:
|
| 441 |
+
torch.Tensor: Transformed value (#batch, time1, d_model)
|
| 442 |
+
weighted by the attention score (#batch, time1, time2).
|
| 443 |
+
|
| 444 |
+
"""
|
| 445 |
+
n_batch = value.size(0)
|
| 446 |
+
# NOTE(xcsong): When will `if mask.size(2) > 0` be True?
|
| 447 |
+
# 1. onnx(16/4) [WHY? Because we feed real cache & real mask for the
|
| 448 |
+
# 1st chunk to ease the onnx export.]
|
| 449 |
+
# 2. pytorch training
|
| 450 |
+
if mask.size(2) > 0: # time2 > 0
|
| 451 |
+
mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
|
| 452 |
+
# For last chunk, time2 might be larger than scores.size(-1)
|
| 453 |
+
mask = mask[:, :, :, :scores.size(-1)] # (batch, 1, *, time2)
|
| 454 |
+
scores = scores.masked_fill(mask, -float('inf'))
|
| 455 |
+
attn = torch.softmax(scores, dim=-1).masked_fill(
|
| 456 |
+
mask, 0.0) # (batch, head, time1, time2)
|
| 457 |
+
# NOTE(xcsong): When will `if mask.size(2) > 0` be False?
|
| 458 |
+
# 1. onnx(16/-1, -1/-1, 16/0)
|
| 459 |
+
# 2. jit (16/-1, -1/-1, 16/0, 16/4)
|
| 460 |
+
else:
|
| 461 |
+
attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
|
| 462 |
+
|
| 463 |
+
p_attn = self.dropout(attn)
|
| 464 |
+
x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
|
| 465 |
+
x = (x.transpose(1, 2).contiguous().view(n_batch, -1,
|
| 466 |
+
self.h * self.d_k)
|
| 467 |
+
) # (batch, time1, d_model)
|
| 468 |
+
|
| 469 |
+
return self.linear_out(x) # (batch, time1, d_model)
|
| 470 |
+
|
| 471 |
+
def forward(
|
| 472 |
+
self,
|
| 473 |
+
query: torch.Tensor,
|
| 474 |
+
key: torch.Tensor,
|
| 475 |
+
value: torch.Tensor,
|
| 476 |
+
mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
|
| 477 |
+
pos_emb: torch.Tensor = torch.empty(0),
|
| 478 |
+
cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
|
| 479 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 480 |
+
"""Compute scaled dot product attention.
|
| 481 |
+
|
| 482 |
+
Args:
|
| 483 |
+
query (torch.Tensor): Query tensor (#batch, time1, size).
|
| 484 |
+
key (torch.Tensor): Key tensor (#batch, time2, size).
|
| 485 |
+
value (torch.Tensor): Value tensor (#batch, time2, size).
|
| 486 |
+
mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
|
| 487 |
+
(#batch, time1, time2).
|
| 488 |
+
1.When applying cross attention between decoder and encoder,
|
| 489 |
+
the batch padding mask for input is in (#batch, 1, T) shape.
|
| 490 |
+
2.When applying self attention of encoder,
|
| 491 |
+
the mask is in (#batch, T, T) shape.
|
| 492 |
+
3.When applying self attention of decoder,
|
| 493 |
+
the mask is in (#batch, L, L) shape.
|
| 494 |
+
4.If the different position in decoder see different block
|
| 495 |
+
of the encoder, such as Mocha, the passed in mask could be
|
| 496 |
+
in (#batch, L, T) shape. But there is no such case in current
|
| 497 |
+
CosyVoice.
|
| 498 |
+
cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
|
| 499 |
+
where `cache_t == chunk_size * num_decoding_left_chunks`
|
| 500 |
+
and `head * d_k == size`
|
| 501 |
+
|
| 502 |
+
|
| 503 |
+
Returns:
|
| 504 |
+
torch.Tensor: Output tensor (#batch, time1, d_model).
|
| 505 |
+
torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
|
| 506 |
+
where `cache_t == chunk_size * num_decoding_left_chunks`
|
| 507 |
+
and `head * d_k == size`
|
| 508 |
+
|
| 509 |
+
"""
|
| 510 |
+
q, k, v = self.forward_qkv(query, key, value)
|
| 511 |
+
|
| 512 |
+
# NOTE(xcsong):
|
| 513 |
+
# when export onnx model, for 1st chunk, we feed
|
| 514 |
+
# cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
|
| 515 |
+
# or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
|
| 516 |
+
# In all modes, `if cache.size(0) > 0` will always be `True`
|
| 517 |
+
# and we will always do splitting and
|
| 518 |
+
# concatenation (this will simplify onnx export). Note that
|
| 519 |
+
# it's OK to concat & split zero-shaped tensors(see code below).
|
| 520 |
+
# when export jit model, for 1st chunk, we always feed
|
| 521 |
+
# cache(0, 0, 0, 0) since jit supports dynamic if-branch.
|
| 522 |
+
# >>> a = torch.ones((1, 2, 0, 4))
|
| 523 |
+
# >>> b = torch.ones((1, 2, 3, 4))
|
| 524 |
+
# >>> c = torch.cat((a, b), dim=2)
|
| 525 |
+
# >>> torch.equal(b, c) # True
|
| 526 |
+
# >>> d = torch.split(a, 2, dim=-1)
|
| 527 |
+
# >>> torch.equal(d[0], d[1]) # True
|
| 528 |
+
if cache.size(0) > 0:
|
| 529 |
+
key_cache, value_cache = torch.split(cache,
|
| 530 |
+
cache.size(-1) // 2,
|
| 531 |
+
dim=-1)
|
| 532 |
+
k = torch.cat([key_cache, k], dim=2)
|
| 533 |
+
v = torch.cat([value_cache, v], dim=2)
|
| 534 |
+
# NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
|
| 535 |
+
# non-trivial to calculate `next_cache_start` here.
|
| 536 |
+
new_cache = torch.cat((k, v), dim=-1)
|
| 537 |
+
|
| 538 |
+
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
|
| 539 |
+
return self.forward_attention(v, scores, mask), new_cache
|
| 540 |
+
|
| 541 |
+
|
| 542 |
+
class RelPositionMultiHeadedAttention(MultiHeadedAttention):
|
| 543 |
+
"""Multi-Head Attention layer with relative position encoding.
|
| 544 |
+
Paper: https://arxiv.org/abs/1901.02860
|
| 545 |
+
Args:
|
| 546 |
+
n_head (int): The number of heads.
|
| 547 |
+
n_feat (int): The number of features.
|
| 548 |
+
dropout_rate (float): Dropout rate.
|
| 549 |
+
key_bias (bool): Whether to use bias in key linear layer.
|
| 550 |
+
"""
|
| 551 |
+
|
| 552 |
+
def __init__(self,
|
| 553 |
+
n_head: int,
|
| 554 |
+
n_feat: int,
|
| 555 |
+
dropout_rate: float,
|
| 556 |
+
key_bias: bool = True):
|
| 557 |
+
super().__init__(n_head, n_feat, dropout_rate, key_bias)
|
| 558 |
+
# linear transformation for positional encoding
|
| 559 |
+
self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
|
| 560 |
+
# these two learnable bias are used in matrix c and matrix d
|
| 561 |
+
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
|
| 562 |
+
self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
|
| 563 |
+
self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
|
| 564 |
+
torch.nn.init.xavier_uniform_(self.pos_bias_u)
|
| 565 |
+
torch.nn.init.xavier_uniform_(self.pos_bias_v)
|
| 566 |
+
|
| 567 |
+
def rel_shift(self, x: torch.Tensor) -> torch.Tensor:
|
| 568 |
+
"""Compute relative positional encoding.
|
| 569 |
+
|
| 570 |
+
Args:
|
| 571 |
+
x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1).
|
| 572 |
+
time1 means the length of query vector.
|
| 573 |
+
|
| 574 |
+
Returns:
|
| 575 |
+
torch.Tensor: Output tensor.
|
| 576 |
+
|
| 577 |
+
"""
|
| 578 |
+
zero_pad = torch.zeros((x.size()[0], x.size()[1], x.size()[2], 1),
|
| 579 |
+
device=x.device,
|
| 580 |
+
dtype=x.dtype)
|
| 581 |
+
x_padded = torch.cat([zero_pad, x], dim=-1)
|
| 582 |
+
|
| 583 |
+
x_padded = x_padded.view(x.size()[0],
|
| 584 |
+
x.size()[1],
|
| 585 |
+
x.size(3) + 1, x.size(2))
|
| 586 |
+
x = x_padded[:, :, 1:].view_as(x)[
|
| 587 |
+
:, :, :, : x.size(-1) // 2 + 1
|
| 588 |
+
] # only keep the positions from 0 to time2
|
| 589 |
+
return x
|
| 590 |
+
|
| 591 |
+
def forward(
|
| 592 |
+
self,
|
| 593 |
+
query: torch.Tensor,
|
| 594 |
+
key: torch.Tensor,
|
| 595 |
+
value: torch.Tensor,
|
| 596 |
+
mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
|
| 597 |
+
pos_emb: torch.Tensor = torch.empty(0),
|
| 598 |
+
cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
|
| 599 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 600 |
+
"""Compute 'Scaled Dot Product Attention' with rel. positional encoding.
|
| 601 |
+
Args:
|
| 602 |
+
query (torch.Tensor): Query tensor (#batch, time1, size).
|
| 603 |
+
key (torch.Tensor): Key tensor (#batch, time2, size).
|
| 604 |
+
value (torch.Tensor): Value tensor (#batch, time2, size).
|
| 605 |
+
mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
|
| 606 |
+
(#batch, time1, time2), (0, 0, 0) means fake mask.
|
| 607 |
+
pos_emb (torch.Tensor): Positional embedding tensor
|
| 608 |
+
(#batch, time2, size).
|
| 609 |
+
cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
|
| 610 |
+
where `cache_t == chunk_size * num_decoding_left_chunks`
|
| 611 |
+
and `head * d_k == size`
|
| 612 |
+
Returns:
|
| 613 |
+
torch.Tensor: Output tensor (#batch, time1, d_model).
|
| 614 |
+
torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
|
| 615 |
+
where `cache_t == chunk_size * num_decoding_left_chunks`
|
| 616 |
+
and `head * d_k == size`
|
| 617 |
+
"""
|
| 618 |
+
q, k, v = self.forward_qkv(query, key, value)
|
| 619 |
+
q = q.transpose(1, 2) # (batch, time1, head, d_k)
|
| 620 |
+
|
| 621 |
+
# NOTE(xcsong):
|
| 622 |
+
# when export onnx model, for 1st chunk, we feed
|
| 623 |
+
# cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
|
| 624 |
+
# or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
|
| 625 |
+
# In all modes, `if cache.size(0) > 0` will always be `True`
|
| 626 |
+
# and we will always do splitting and
|
| 627 |
+
# concatenation (this will simplify onnx export). Note that
|
| 628 |
+
# it's OK to concat & split zero-shaped tensors(see code below).
|
| 629 |
+
# when export jit model, for 1st chunk, we always feed
|
| 630 |
+
# cache(0, 0, 0, 0) since jit supports dynamic if-branch.
|
| 631 |
+
# >>> a = torch.ones((1, 2, 0, 4))
|
| 632 |
+
# >>> b = torch.ones((1, 2, 3, 4))
|
| 633 |
+
# >>> c = torch.cat((a, b), dim=2)
|
| 634 |
+
# >>> torch.equal(b, c) # True
|
| 635 |
+
# >>> d = torch.split(a, 2, dim=-1)
|
| 636 |
+
# >>> torch.equal(d[0], d[1]) # True
|
| 637 |
+
if cache.size(0) > 0:
|
| 638 |
+
key_cache, value_cache = torch.split(cache,
|
| 639 |
+
cache.size(-1) // 2,
|
| 640 |
+
dim=-1)
|
| 641 |
+
k = torch.cat([key_cache, k], dim=2)
|
| 642 |
+
v = torch.cat([value_cache, v], dim=2)
|
| 643 |
+
# NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
|
| 644 |
+
# non-trivial to calculate `next_cache_start` here.
|
| 645 |
+
new_cache = torch.cat((k, v), dim=-1)
|
| 646 |
+
|
| 647 |
+
n_batch_pos = pos_emb.size(0)
|
| 648 |
+
p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
|
| 649 |
+
p = p.transpose(1, 2) # (batch, head, time1, d_k)
|
| 650 |
+
|
| 651 |
+
# (batch, head, time1, d_k)
|
| 652 |
+
q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
|
| 653 |
+
# (batch, head, time1, d_k)
|
| 654 |
+
q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
|
| 655 |
+
|
| 656 |
+
# compute attention score
|
| 657 |
+
# first compute matrix a and matrix c
|
| 658 |
+
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
|
| 659 |
+
# (batch, head, time1, time2)
|
| 660 |
+
matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
|
| 661 |
+
|
| 662 |
+
# compute matrix b and matrix d
|
| 663 |
+
# (batch, head, time1, time2)
|
| 664 |
+
matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
|
| 665 |
+
# NOTE(Xiang Lyu): Keep rel_shift since espnet rel_pos_emb is used
|
| 666 |
+
if matrix_ac.shape != matrix_bd.shape:
|
| 667 |
+
matrix_bd = self.rel_shift(matrix_bd)
|
| 668 |
+
|
| 669 |
+
scores = (matrix_ac + matrix_bd) / math.sqrt(
|
| 670 |
+
self.d_k) # (batch, head, time1, time2)
|
| 671 |
+
|
| 672 |
+
return self.forward_attention(v, scores, mask), new_cache
|
| 673 |
+
|
| 674 |
+
|
| 675 |
+
class PositionwiseFeedForward(torch.nn.Module):
|
| 676 |
+
"""Positionwise feed forward layer.
|
| 677 |
+
|
| 678 |
+
FeedForward is applied at each position of the sequence.
|
| 679 |
+
The output dim is the same as the input dim.
|
| 680 |
+
|
| 681 |
+
Args:
|
| 682 |
+
idim (int): Input dimension.
|
| 683 |
+
hidden_units (int): The number of hidden units.
|
| 684 |
+
dropout_rate (float): Dropout rate.
|
| 685 |
+
activation (torch.nn.Module): Activation function
|
| 686 |
+
"""
|
| 687 |
+
|
| 688 |
+
def __init__(
|
| 689 |
+
self,
|
| 690 |
+
idim: int,
|
| 691 |
+
hidden_units: int,
|
| 692 |
+
dropout_rate: float,
|
| 693 |
+
activation: torch.nn.Module = torch.nn.ReLU(),
|
| 694 |
+
):
|
| 695 |
+
super(PositionwiseFeedForward, self).__init__()
|
| 696 |
+
self.w_1 = torch.nn.Linear(idim, hidden_units)
|
| 697 |
+
self.activation = activation
|
| 698 |
+
self.dropout = torch.nn.Dropout(dropout_rate)
|
| 699 |
+
self.w_2 = torch.nn.Linear(hidden_units, idim)
|
| 700 |
+
|
| 701 |
+
def forward(self, xs: torch.Tensor) -> torch.Tensor:
|
| 702 |
+
"""Forward function.
|
| 703 |
+
|
| 704 |
+
Args:
|
| 705 |
+
xs: input tensor (B, L, D)
|
| 706 |
+
Returns:
|
| 707 |
+
output tensor, (B, L, D)
|
| 708 |
+
"""
|
| 709 |
+
return self.w_2(self.dropout(self.activation(self.w_1(xs))))
|
| 710 |
+
|
| 711 |
+
|
| 712 |
+
class ConformerEncoderLayer(nn.Module):
|
| 713 |
+
"""Encoder layer module.
|
| 714 |
+
Args:
|
| 715 |
+
size (int): Input dimension.
|
| 716 |
+
self_attn (torch.nn.Module): Self-attention module instance.
|
| 717 |
+
`MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
|
| 718 |
+
instance can be used as the argument.
|
| 719 |
+
feed_forward (torch.nn.Module): Feed-forward module instance.
|
| 720 |
+
`PositionwiseFeedForward` instance can be used as the argument.
|
| 721 |
+
feed_forward_macaron (torch.nn.Module): Additional feed-forward module
|
| 722 |
+
instance.
|
| 723 |
+
`PositionwiseFeedForward` instance can be used as the argument.
|
| 724 |
+
conv_module (torch.nn.Module): Convolution module instance.
|
| 725 |
+
`ConvolutionModule` instance can be used as the argument.
|
| 726 |
+
dropout_rate (float): Dropout rate.
|
| 727 |
+
normalize_before (bool):
|
| 728 |
+
True: use layer_norm before each sub-block.
|
| 729 |
+
False: use layer_norm after each sub-block.
|
| 730 |
+
"""
|
| 731 |
+
|
| 732 |
+
def __init__(
|
| 733 |
+
self,
|
| 734 |
+
size: int,
|
| 735 |
+
self_attn: torch.nn.Module,
|
| 736 |
+
feed_forward: Optional[nn.Module] = None,
|
| 737 |
+
feed_forward_macaron: Optional[nn.Module] = None,
|
| 738 |
+
conv_module: Optional[nn.Module] = None,
|
| 739 |
+
dropout_rate: float = 0.0,
|
| 740 |
+
normalize_before: bool = True,
|
| 741 |
+
):
|
| 742 |
+
super().__init__()
|
| 743 |
+
self.self_attn = self_attn
|
| 744 |
+
self.feed_forward = feed_forward
|
| 745 |
+
self.feed_forward_macaron = feed_forward_macaron
|
| 746 |
+
self.conv_module = conv_module
|
| 747 |
+
self.norm_ff = nn.LayerNorm(size, eps=1e-12) # for the FNN module
|
| 748 |
+
self.norm_mha = nn.LayerNorm(size, eps=1e-12) # for the MHA module
|
| 749 |
+
if feed_forward_macaron is not None:
|
| 750 |
+
self.norm_ff_macaron = nn.LayerNorm(size, eps=1e-12)
|
| 751 |
+
self.ff_scale = 0.5
|
| 752 |
+
else:
|
| 753 |
+
self.ff_scale = 1.0
|
| 754 |
+
if self.conv_module is not None:
|
| 755 |
+
self.norm_conv = nn.LayerNorm(size, eps=1e-12) # for the CNN module
|
| 756 |
+
self.norm_final = nn.LayerNorm(
|
| 757 |
+
size, eps=1e-12) # for the final output of the block
|
| 758 |
+
self.dropout = nn.Dropout(dropout_rate)
|
| 759 |
+
self.size = size
|
| 760 |
+
self.normalize_before = normalize_before
|
| 761 |
+
|
| 762 |
+
def forward(
|
| 763 |
+
self,
|
| 764 |
+
x: torch.Tensor,
|
| 765 |
+
mask: torch.Tensor,
|
| 766 |
+
pos_emb: torch.Tensor,
|
| 767 |
+
mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
|
| 768 |
+
att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
|
| 769 |
+
cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
|
| 770 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 771 |
+
"""Compute encoded features.
|
| 772 |
+
|
| 773 |
+
Args:
|
| 774 |
+
x (torch.Tensor): (#batch, time, size)
|
| 775 |
+
mask (torch.Tensor): Mask tensor for the input (#batch, time,time),
|
| 776 |
+
(0, 0, 0) means fake mask.
|
| 777 |
+
pos_emb (torch.Tensor): positional encoding, must not be None
|
| 778 |
+
for ConformerEncoderLayer.
|
| 779 |
+
mask_pad (torch.Tensor): batch padding mask used for conv module.
|
| 780 |
+
(#batch, 1,time), (0, 0, 0) means fake mask.
|
| 781 |
+
att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
|
| 782 |
+
(#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
|
| 783 |
+
cnn_cache (torch.Tensor): Convolution cache in conformer layer
|
| 784 |
+
(#batch=1, size, cache_t2)
|
| 785 |
+
Returns:
|
| 786 |
+
torch.Tensor: Output tensor (#batch, time, size).
|
| 787 |
+
torch.Tensor: Mask tensor (#batch, time, time).
|
| 788 |
+
torch.Tensor: att_cache tensor,
|
| 789 |
+
(#batch=1, head, cache_t1 + time, d_k * 2).
|
| 790 |
+
torch.Tensor: cnn_cache tensor (#batch, size, cache_t2).
|
| 791 |
+
"""
|
| 792 |
+
|
| 793 |
+
# whether to use macaron style
|
| 794 |
+
if self.feed_forward_macaron is not None:
|
| 795 |
+
residual = x
|
| 796 |
+
if self.normalize_before:
|
| 797 |
+
x = self.norm_ff_macaron(x)
|
| 798 |
+
x = residual + self.ff_scale * self.dropout(
|
| 799 |
+
self.feed_forward_macaron(x))
|
| 800 |
+
if not self.normalize_before:
|
| 801 |
+
x = self.norm_ff_macaron(x)
|
| 802 |
+
|
| 803 |
+
# multi-headed self-attention module
|
| 804 |
+
residual = x
|
| 805 |
+
if self.normalize_before:
|
| 806 |
+
x = self.norm_mha(x)
|
| 807 |
+
x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb,
|
| 808 |
+
att_cache)
|
| 809 |
+
x = residual + self.dropout(x_att)
|
| 810 |
+
if not self.normalize_before:
|
| 811 |
+
x = self.norm_mha(x)
|
| 812 |
+
|
| 813 |
+
# convolution module
|
| 814 |
+
# Fake new cnn cache here, and then change it in conv_module
|
| 815 |
+
new_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
|
| 816 |
+
if self.conv_module is not None:
|
| 817 |
+
residual = x
|
| 818 |
+
if self.normalize_before:
|
| 819 |
+
x = self.norm_conv(x)
|
| 820 |
+
x, new_cnn_cache = self.conv_module(x, mask_pad, cnn_cache)
|
| 821 |
+
x = residual + self.dropout(x)
|
| 822 |
+
|
| 823 |
+
if not self.normalize_before:
|
| 824 |
+
x = self.norm_conv(x)
|
| 825 |
+
|
| 826 |
+
# feed forward module
|
| 827 |
+
residual = x
|
| 828 |
+
if self.normalize_before:
|
| 829 |
+
x = self.norm_ff(x)
|
| 830 |
+
|
| 831 |
+
x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
|
| 832 |
+
if not self.normalize_before:
|
| 833 |
+
x = self.norm_ff(x)
|
| 834 |
+
|
| 835 |
+
if self.conv_module is not None:
|
| 836 |
+
x = self.norm_final(x)
|
| 837 |
+
|
| 838 |
+
return x, mask, new_att_cache, new_cnn_cache
|
| 839 |
+
|
| 840 |
+
|
| 841 |
+
class UpsampleConformerEncoder(torch.nn.Module):
|
| 842 |
+
"""
|
| 843 |
+
Args:
|
| 844 |
+
input_size (int): input dim
|
| 845 |
+
output_size (int): dimension of attention
|
| 846 |
+
attention_heads (int): the number of heads of multi head attention
|
| 847 |
+
linear_units (int): the hidden units number of position-wise feed
|
| 848 |
+
forward
|
| 849 |
+
num_blocks (int): the number of decoder blocks
|
| 850 |
+
static_chunk_size (int): chunk size for static chunk training and
|
| 851 |
+
decoding
|
| 852 |
+
use_dynamic_chunk (bool): whether use dynamic chunk size for
|
| 853 |
+
training or not. You can only use a fixed chunk (chunk_size > 0)
|
| 854 |
+
or dynamic chunk size (use_dynamic_chunk = True)
|
| 855 |
+
use_dynamic_left_chunk (bool): whether use dynamic left chunk in
|
| 856 |
+
dynamic chunk training
|
| 857 |
+
key_bias: whether to use bias in attention.linear_k, False for whisper models.
|
| 858 |
+
"""
|
| 859 |
+
|
| 860 |
+
def __init__(
|
| 861 |
+
self,
|
| 862 |
+
input_size: int = 512,
|
| 863 |
+
output_size: int = 512,
|
| 864 |
+
attention_heads: int = 8,
|
| 865 |
+
linear_units: int = 2048,
|
| 866 |
+
num_blocks: int = 6,
|
| 867 |
+
static_chunk_size: int = 25,
|
| 868 |
+
use_dynamic_chunk: bool = False,
|
| 869 |
+
use_dynamic_left_chunk: bool = False,
|
| 870 |
+
key_bias: bool = True,
|
| 871 |
+
):
|
| 872 |
+
super().__init__()
|
| 873 |
+
self._output_size = output_size
|
| 874 |
+
|
| 875 |
+
self.embed = LinearNoSubsampling(
|
| 876 |
+
input_size, output_size,
|
| 877 |
+
EspnetRelPositionalEncoding(output_size),
|
| 878 |
+
)
|
| 879 |
+
|
| 880 |
+
self.after_norm = torch.nn.LayerNorm(output_size, eps=1e-5)
|
| 881 |
+
self.static_chunk_size = static_chunk_size
|
| 882 |
+
self.use_dynamic_chunk = use_dynamic_chunk
|
| 883 |
+
self.use_dynamic_left_chunk = use_dynamic_left_chunk
|
| 884 |
+
activation = torch.nn.SiLU()
|
| 885 |
+
# self-attention module definition
|
| 886 |
+
encoder_selfattn_layer_args = (
|
| 887 |
+
attention_heads,
|
| 888 |
+
output_size,
|
| 889 |
+
0.0,
|
| 890 |
+
key_bias,
|
| 891 |
+
)
|
| 892 |
+
# feed-forward module definition
|
| 893 |
+
positionwise_layer_args = (
|
| 894 |
+
output_size,
|
| 895 |
+
linear_units,
|
| 896 |
+
0.0,
|
| 897 |
+
activation,
|
| 898 |
+
)
|
| 899 |
+
# convolution module definition
|
| 900 |
+
self.pre_lookahead_layer = PreLookaheadLayer(channels=512, pre_lookahead_len=3)
|
| 901 |
+
self.encoders = torch.nn.ModuleList([
|
| 902 |
+
ConformerEncoderLayer(
|
| 903 |
+
output_size,
|
| 904 |
+
RelPositionMultiHeadedAttention(*encoder_selfattn_layer_args),
|
| 905 |
+
PositionwiseFeedForward(*positionwise_layer_args),
|
| 906 |
+
) for _ in range(num_blocks)
|
| 907 |
+
])
|
| 908 |
+
self.up_layer = Upsample1D(channels=512, out_channels=512, stride=2)
|
| 909 |
+
self.up_embed = LinearNoSubsampling(
|
| 910 |
+
input_size, output_size,
|
| 911 |
+
EspnetRelPositionalEncoding(output_size),
|
| 912 |
+
)
|
| 913 |
+
self.up_encoders = torch.nn.ModuleList([
|
| 914 |
+
ConformerEncoderLayer(
|
| 915 |
+
output_size,
|
| 916 |
+
RelPositionMultiHeadedAttention(*encoder_selfattn_layer_args),
|
| 917 |
+
PositionwiseFeedForward(*positionwise_layer_args),
|
| 918 |
+
) for _ in range(4)
|
| 919 |
+
])
|
| 920 |
+
|
| 921 |
+
def output_size(self) -> int:
|
| 922 |
+
return self._output_size
|
| 923 |
+
|
| 924 |
+
def forward(
|
| 925 |
+
self,
|
| 926 |
+
xs: torch.Tensor,
|
| 927 |
+
xs_lens: torch.Tensor,
|
| 928 |
+
context: torch.Tensor = torch.zeros(0, 0, 0),
|
| 929 |
+
decoding_chunk_size: int = 0,
|
| 930 |
+
num_decoding_left_chunks: int = -1,
|
| 931 |
+
streaming: bool = False,
|
| 932 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 933 |
+
"""Embed positions in tensor.
|
| 934 |
+
|
| 935 |
+
Args:
|
| 936 |
+
xs: padded input tensor (B, T, D)
|
| 937 |
+
xs_lens: input length (B)
|
| 938 |
+
decoding_chunk_size: decoding chunk size for dynamic chunk
|
| 939 |
+
0: default for training, use random dynamic chunk.
|
| 940 |
+
<0: for decoding, use full chunk.
|
| 941 |
+
>0: for decoding, use fixed chunk size as set.
|
| 942 |
+
num_decoding_left_chunks: number of left chunks, this is for decoding,
|
| 943 |
+
the chunk size is decoding_chunk_size.
|
| 944 |
+
>=0: use num_decoding_left_chunks
|
| 945 |
+
<0: use all left chunks
|
| 946 |
+
Returns:
|
| 947 |
+
encoder output tensor xs, and subsampled masks
|
| 948 |
+
xs: padded output tensor (B, T' ~= T/subsample_rate, D)
|
| 949 |
+
masks: torch.Tensor batch padding mask after subsample
|
| 950 |
+
(B, 1, T' ~= T/subsample_rate)
|
| 951 |
+
NOTE(xcsong):
|
| 952 |
+
We pass the `__call__` method of the modules instead of `forward` to the
|
| 953 |
+
checkpointing API because `__call__` attaches all the hooks of the module.
|
| 954 |
+
https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
|
| 955 |
+
"""
|
| 956 |
+
T = xs.size(1)
|
| 957 |
+
masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T)
|
| 958 |
+
xs, pos_emb, masks = self.embed(xs, masks)
|
| 959 |
+
if context.size(1) != 0:
|
| 960 |
+
assert self.training is False, 'you have passed context, make sure that you are running in inference mode'
|
| 961 |
+
context_masks = torch.ones(1, 1, context.size(1)).to(masks)
|
| 962 |
+
context, _, _ = self.embed(context, context_masks, offset=xs.size(1))
|
| 963 |
+
mask_pad = masks # (B, 1, T/subsample_rate)
|
| 964 |
+
chunk_masks = add_optional_chunk_mask(xs, masks, False, False, 0, self.static_chunk_size if streaming is True else 0, -1)
|
| 965 |
+
# lookahead + conformer encoder
|
| 966 |
+
xs = self.pre_lookahead_layer(xs, context=context)
|
| 967 |
+
xs = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad)
|
| 968 |
+
|
| 969 |
+
# upsample + conformer encoder
|
| 970 |
+
xs = xs.transpose(1, 2).contiguous()
|
| 971 |
+
xs, xs_lens = self.up_layer(xs, xs_lens)
|
| 972 |
+
xs = xs.transpose(1, 2).contiguous()
|
| 973 |
+
T = xs.size(1)
|
| 974 |
+
masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T)
|
| 975 |
+
xs, pos_emb, masks = self.up_embed(xs, masks)
|
| 976 |
+
mask_pad = masks # (B, 1, T/subsample_rate)
|
| 977 |
+
chunk_masks = add_optional_chunk_mask(xs, masks, False, False, 0, self.static_chunk_size * self.up_layer.stride if streaming is True else 0, -1)
|
| 978 |
+
xs = self.forward_up_layers(xs, chunk_masks, pos_emb, mask_pad)
|
| 979 |
+
|
| 980 |
+
xs = self.after_norm(xs)
|
| 981 |
+
# Here we assume the mask is not changed in encoder layers, so just
|
| 982 |
+
# return the masks before encoder layers, and the masks will be used
|
| 983 |
+
# for cross attention with decoder later
|
| 984 |
+
return xs, masks
|
| 985 |
+
|
| 986 |
+
def forward_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor,
|
| 987 |
+
pos_emb: torch.Tensor,
|
| 988 |
+
mask_pad: torch.Tensor) -> torch.Tensor:
|
| 989 |
+
for layer in self.encoders:
|
| 990 |
+
xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
|
| 991 |
+
return xs
|
| 992 |
+
|
| 993 |
+
def forward_up_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor,
|
| 994 |
+
pos_emb: torch.Tensor,
|
| 995 |
+
mask_pad: torch.Tensor) -> torch.Tensor:
|
| 996 |
+
for layer in self.up_encoders:
|
| 997 |
+
xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
|
| 998 |
+
return xs
|
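The `rel_shift` method above reindexes relative-position scores of shape (batch, head, time1, 2*time1-1) into (batch, head, time1, time1) using only zero-padding and a reshape. A minimal, self-contained sketch of that reindexing (illustrative only, not part of the repository):

import torch

def rel_shift_demo(x: torch.Tensor) -> torch.Tensor:
    # x: (batch, head, time1, 2*time1 - 1) relative-position scores
    b, h, t1, _ = x.size()
    zero_pad = torch.zeros((b, h, t1, 1), device=x.device, dtype=x.dtype)
    x_padded = torch.cat([zero_pad, x], dim=-1)   # (b, h, t1, 2*t1)
    x_padded = x_padded.view(b, h, 2 * t1, t1)    # fold the two time axes
    # drop the leading row and keep key positions 0..t1-1
    return x_padded[:, :, 1:].view_as(x)[..., :t1]

scores = torch.randn(1, 2, 4, 7)                  # time1 = 4
print(rel_shift_demo(scores).shape)               # torch.Size([1, 2, 4, 4])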
soulxpodcast/models/modules/hifigan.py
ADDED
|
@@ -0,0 +1,249 @@
| 1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu)
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""HIFI-GAN"""
|
| 16 |
+
|
| 17 |
+
from typing import Dict, List
|
| 18 |
+
|
| 19 |
+
import numpy as np
|
| 20 |
+
import torch
|
| 21 |
+
import torch.nn as nn
|
| 22 |
+
import torch.nn.functional as F
|
| 23 |
+
from scipy.signal import get_window
|
| 24 |
+
from torch.nn import Conv1d, ConvTranspose1d
|
| 25 |
+
from torch.nn.utils import remove_weight_norm
|
| 26 |
+
|
| 27 |
+
try:
|
| 28 |
+
from torch.nn.utils.parametrizations import weight_norm
|
| 29 |
+
except ImportError:
|
| 30 |
+
from torch.nn.utils import weight_norm # noqa
|
| 31 |
+
|
| 32 |
+
from soulxpodcast.models.modules.hifigan_components.layers import (
|
| 33 |
+
ResBlock, SourceModuleHnNSF, SourceModuleHnNSF2, init_weights)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class ConvRNNF0Predictor(nn.Module):
|
| 37 |
+
def __init__(self,
|
| 38 |
+
num_class: int = 1,
|
| 39 |
+
in_channels: int = 80,
|
| 40 |
+
cond_channels: int = 512
|
| 41 |
+
):
|
| 42 |
+
super().__init__()
|
| 43 |
+
|
| 44 |
+
self.num_class = num_class
|
| 45 |
+
self.condnet = nn.Sequential(
|
| 46 |
+
weight_norm( # noqa
|
| 47 |
+
nn.Conv1d(in_channels, cond_channels, kernel_size=3, padding=1)
|
| 48 |
+
),
|
| 49 |
+
nn.ELU(),
|
| 50 |
+
weight_norm( # noqa
|
| 51 |
+
nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
|
| 52 |
+
),
|
| 53 |
+
nn.ELU(),
|
| 54 |
+
weight_norm( # noqa
|
| 55 |
+
nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
|
| 56 |
+
),
|
| 57 |
+
nn.ELU(),
|
| 58 |
+
weight_norm( # noqa
|
| 59 |
+
nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
|
| 60 |
+
),
|
| 61 |
+
nn.ELU(),
|
| 62 |
+
weight_norm( # noqa
|
| 63 |
+
nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
|
| 64 |
+
),
|
| 65 |
+
nn.ELU(),
|
| 66 |
+
)
|
| 67 |
+
self.classifier = nn.Linear(in_features=cond_channels, out_features=self.num_class)
|
| 68 |
+
|
| 69 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 70 |
+
x = self.condnet(x)
|
| 71 |
+
x = x.transpose(1, 2)
|
| 72 |
+
return torch.abs(self.classifier(x).squeeze(-1))
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
class HiFTGenerator(nn.Module):
|
| 76 |
+
"""
|
| 77 |
+
HiFTNet Generator: Neural Source Filter + ISTFTNet
|
| 78 |
+
https://arxiv.org/abs/2309.09493
|
| 79 |
+
"""
|
| 80 |
+
def __init__(
|
| 81 |
+
self,
|
| 82 |
+
in_channels: int = 80,
|
| 83 |
+
base_channels: int = 512,
|
| 84 |
+
nb_harmonics: int = 8,
|
| 85 |
+
sampling_rate: int = 24000,
|
| 86 |
+
nsf_alpha: float = 0.1,
|
| 87 |
+
nsf_sigma: float = 0.003,
|
| 88 |
+
nsf_voiced_threshold: float = 10,
|
| 89 |
+
upsample_rates: List[int] = [8, 5, 3], # noqa
|
| 90 |
+
upsample_kernel_sizes: List[int] = [16, 11, 7], # noqa
|
| 91 |
+
istft_params: Dict[str, int] = {"n_fft": 16, "hop_len": 4}, # noqa
|
| 92 |
+
resblock_kernel_sizes: List[int] = [3, 7, 11], # noqa
|
| 93 |
+
resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]], # noqa
|
| 94 |
+
source_resblock_kernel_sizes: List[int] = [7, 7, 11], # noqa
|
| 95 |
+
source_resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]], # noqa
|
| 96 |
+
lrelu_slope: float = 0.1,
|
| 97 |
+
audio_limit: float = 0.99,
|
| 98 |
+
f0_predictor: torch.nn.Module = None,
|
| 99 |
+
):
|
| 100 |
+
super(HiFTGenerator, self).__init__()
|
| 101 |
+
|
| 102 |
+
self.out_channels = 1
|
| 103 |
+
self.nb_harmonics = nb_harmonics
|
| 104 |
+
self.sampling_rate = sampling_rate
|
| 105 |
+
self.istft_params = istft_params
|
| 106 |
+
self.lrelu_slope = lrelu_slope
|
| 107 |
+
self.audio_limit = audio_limit
|
| 108 |
+
|
| 109 |
+
self.num_kernels = len(resblock_kernel_sizes)
|
| 110 |
+
self.num_upsamples = len(upsample_rates)
|
| 111 |
+
# NOTE in CosyVoice2, we use the original SourceModuleHnNSF implementation
|
| 112 |
+
this_SourceModuleHnNSF = SourceModuleHnNSF if self.sampling_rate == 22050 else SourceModuleHnNSF2
|
| 113 |
+
self.m_source = this_SourceModuleHnNSF(
|
| 114 |
+
sampling_rate=sampling_rate,
|
| 115 |
+
upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"],
|
| 116 |
+
harmonic_num=nb_harmonics,
|
| 117 |
+
sine_amp=nsf_alpha,
|
| 118 |
+
add_noise_std=nsf_sigma,
|
| 119 |
+
voiced_threshod=nsf_voiced_threshold)
|
| 120 |
+
self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * istft_params["hop_len"])
|
| 121 |
+
|
| 122 |
+
self.conv_pre = weight_norm( # noqa
|
| 123 |
+
Conv1d(in_channels, base_channels, 7, 1, padding=3)
|
| 124 |
+
)
|
| 125 |
+
|
| 126 |
+
# Up
|
| 127 |
+
self.ups = nn.ModuleList()
|
| 128 |
+
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
|
| 129 |
+
self.ups.append(
|
| 130 |
+
weight_norm( # noqa
|
| 131 |
+
ConvTranspose1d(
|
| 132 |
+
base_channels // (2**i),
|
| 133 |
+
base_channels // (2**(i + 1)),
|
| 134 |
+
k,
|
| 135 |
+
u,
|
| 136 |
+
padding=(k - u) // 2,
|
| 137 |
+
)
|
| 138 |
+
)
|
| 139 |
+
)
|
| 140 |
+
|
| 141 |
+
# Down
|
| 142 |
+
self.source_downs = nn.ModuleList()
|
| 143 |
+
self.source_resblocks = nn.ModuleList()
|
| 144 |
+
downsample_rates = [1] + upsample_rates[::-1][:-1]
|
| 145 |
+
downsample_cum_rates = np.cumprod(downsample_rates)
|
| 146 |
+
for i, (u, k, d) in enumerate(zip(downsample_cum_rates[::-1], source_resblock_kernel_sizes, source_resblock_dilation_sizes)):
|
| 147 |
+
if u == 1:
|
| 148 |
+
self.source_downs.append(
|
| 149 |
+
Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), 1, 1)
|
| 150 |
+
)
|
| 151 |
+
else:
|
| 152 |
+
self.source_downs.append(
|
| 153 |
+
Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), u * 2, u, padding=(u // 2))
|
| 154 |
+
)
|
| 155 |
+
|
| 156 |
+
self.source_resblocks.append(
|
| 157 |
+
ResBlock(base_channels // (2 ** (i + 1)), k, d)
|
| 158 |
+
)
|
| 159 |
+
|
| 160 |
+
self.resblocks = nn.ModuleList()
|
| 161 |
+
for i in range(len(self.ups)):
|
| 162 |
+
ch = base_channels // (2**(i + 1))
|
| 163 |
+
for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
|
| 164 |
+
self.resblocks.append(ResBlock(ch, k, d))
|
| 165 |
+
|
| 166 |
+
self.conv_post = weight_norm(Conv1d(ch, istft_params["n_fft"] + 2, 7, 1, padding=3)) # noqa
|
| 167 |
+
self.ups.apply(init_weights)
|
| 168 |
+
self.conv_post.apply(init_weights)
|
| 169 |
+
self.reflection_pad = nn.ReflectionPad1d((1, 0))
|
| 170 |
+
self.stft_window = torch.from_numpy(get_window("hann", istft_params["n_fft"], fftbins=True).astype(np.float32))
|
| 171 |
+
self.f0_predictor = ConvRNNF0Predictor() if f0_predictor is None else f0_predictor
|
| 172 |
+
|
| 173 |
+
def remove_weight_norm(self):
|
| 174 |
+
print('Removing weight norm...')
|
| 175 |
+
for up in self.ups:
|
| 176 |
+
remove_weight_norm(up)
|
| 177 |
+
for resblock in self.resblocks:
|
| 178 |
+
resblock.remove_weight_norm()
|
| 179 |
+
remove_weight_norm(self.conv_pre)
|
| 180 |
+
remove_weight_norm(self.conv_post)
|
| 181 |
+
self.m_source.remove_weight_norm()
|
| 182 |
+
for source_down in self.source_downs:
|
| 183 |
+
remove_weight_norm(source_down)
|
| 184 |
+
for source_resblock in self.source_resblocks:
|
| 185 |
+
source_resblock.remove_weight_norm()
|
| 186 |
+
|
| 187 |
+
def _stft(self, x):
|
| 188 |
+
spec = torch.stft(
|
| 189 |
+
x,
|
| 190 |
+
self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window.to(x.device),
|
| 191 |
+
return_complex=True)
|
| 192 |
+
spec = torch.view_as_real(spec) # [B, F, TT, 2]
|
| 193 |
+
return spec[..., 0], spec[..., 1]
|
| 194 |
+
|
| 195 |
+
def _istft(self, magnitude, phase):
|
| 196 |
+
magnitude = torch.clip(magnitude, max=1e2)
|
| 197 |
+
real = magnitude * torch.cos(phase)
|
| 198 |
+
img = magnitude * torch.sin(phase)
|
| 199 |
+
inverse_transform = torch.istft(torch.complex(real, img), self.istft_params["n_fft"], self.istft_params["hop_len"],
|
| 200 |
+
self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device))
|
| 201 |
+
return inverse_transform
|
| 202 |
+
|
| 203 |
+
def decode(self, x: torch.Tensor, s: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
|
| 204 |
+
s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
|
| 205 |
+
s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)
|
| 206 |
+
|
| 207 |
+
x = self.conv_pre(x)
|
| 208 |
+
for i in range(self.num_upsamples):
|
| 209 |
+
x = F.leaky_relu(x, self.lrelu_slope)
|
| 210 |
+
x = self.ups[i](x)
|
| 211 |
+
|
| 212 |
+
if i == self.num_upsamples - 1:
|
| 213 |
+
x = self.reflection_pad(x)
|
| 214 |
+
|
| 215 |
+
# fusion
|
| 216 |
+
si = self.source_downs[i](s_stft)
|
| 217 |
+
si = self.source_resblocks[i](si)
|
| 218 |
+
x = x + si
|
| 219 |
+
|
| 220 |
+
xs = None
|
| 221 |
+
for j in range(self.num_kernels):
|
| 222 |
+
if xs is None:
|
| 223 |
+
xs = self.resblocks[i * self.num_kernels + j](x)
|
| 224 |
+
else:
|
| 225 |
+
xs += self.resblocks[i * self.num_kernels + j](x)
|
| 226 |
+
x = xs / self.num_kernels
|
| 227 |
+
|
| 228 |
+
x = F.leaky_relu(x)
|
| 229 |
+
x = self.conv_post(x)
|
| 230 |
+
magnitude = torch.exp(x[:, :self.istft_params["n_fft"] // 2 + 1, :])
|
| 231 |
+
phase = torch.sin(x[:, self.istft_params["n_fft"] // 2 + 1:, :]) # actually, sin is redundant
|
| 232 |
+
|
| 233 |
+
x = self._istft(magnitude, phase)
|
| 234 |
+
x = torch.clamp(x*0.98, -self.audio_limit, self.audio_limit)
|
| 235 |
+
return x
|
| 236 |
+
|
| 237 |
+
@torch.inference_mode()
|
| 238 |
+
def forward(self, speech_feat: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
|
| 239 |
+
# mel->f0
|
| 240 |
+
f0 = self.f0_predictor(speech_feat)
|
| 241 |
+
# f0->source
|
| 242 |
+
s = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
|
| 243 |
+
s, _, _ = self.m_source(s)
|
| 244 |
+
s = s.transpose(1, 2)
|
| 245 |
+
# use cache_source to avoid glitch
|
| 246 |
+
if cache_source.shape[2] != 0:
|
| 247 |
+
s[:, :, :cache_source.shape[2]] = cache_source
|
| 248 |
+
generated_speech = self.decode(x=speech_feat, s=s)
|
| 249 |
+
return generated_speech, s
|
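`HiFTGenerator.decode` above predicts per-frame log-magnitude and phase and resynthesizes audio through `_istft`. The round trip it relies on can be checked in isolation with the same default parameters (n_fft=16, hop_len=4) and a Hann window; this is an illustrative sketch, not repository code:

import numpy as np
import torch
from scipy.signal import get_window

n_fft, hop_len = 16, 4
window = torch.from_numpy(get_window("hann", n_fft, fftbins=True).astype(np.float32))

wav = torch.randn(1, 400)                                   # arbitrary short signal
spec = torch.stft(wav, n_fft, hop_len, n_fft, window=window, return_complex=True)
magnitude, phase = spec.abs(), spec.angle()                 # magnitude/phase pair as used by decode()

real = magnitude * torch.cos(phase)
imag = magnitude * torch.sin(phase)
recon = torch.istft(torch.complex(real, imag), n_fft, hop_len, n_fft, window=window)
print(torch.allclose(recon, wav[:, :recon.size(-1)], atol=1e-4))  # True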
soulxpodcast/models/modules/hifigan_components/__init__.py
ADDED
|
File without changes
|
soulxpodcast/models/modules/hifigan_components/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (211 Bytes). View file
|
|
|
soulxpodcast/models/modules/hifigan_components/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (199 Bytes). View file
|
|
|
soulxpodcast/models/modules/hifigan_components/__pycache__/layers.cpython-311.pyc
ADDED
|
Binary file (20.5 kB). View file
|
|
|
soulxpodcast/models/modules/hifigan_components/__pycache__/layers.cpython-312.pyc
ADDED
|
Binary file (18.7 kB). View file
|
|
|
soulxpodcast/models/modules/hifigan_components/layers.py
ADDED
|
@@ -0,0 +1,433 @@
| 1 |
+
from typing import List
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
from torch.distributions.uniform import Uniform
|
| 7 |
+
from torch.nn import Conv1d
|
| 8 |
+
from torch.nn.utils import remove_weight_norm
|
| 9 |
+
|
| 10 |
+
try:
|
| 11 |
+
from torch.nn.utils.parametrizations import weight_norm
|
| 12 |
+
except ImportError:
|
| 13 |
+
from torch.nn.utils import weight_norm # noqa
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def get_padding(kernel_size, dilation=1):
|
| 17 |
+
return int((kernel_size * dilation - dilation) / 2)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def init_weights(m, mean=0.0, std=0.01):
|
| 21 |
+
classname = m.__class__.__name__
|
| 22 |
+
if classname.find("Conv") != -1:
|
| 23 |
+
m.weight.data.normal_(mean, std)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
"""hifigan based generator implementation.
|
| 27 |
+
|
| 28 |
+
This code is modified from https://github.com/jik876/hifi-gan
|
| 29 |
+
,https://github.com/kan-bayashi/ParallelWaveGAN and
|
| 30 |
+
https://github.com/NVIDIA/BigVGAN
|
| 31 |
+
|
| 32 |
+
"""
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
# Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
|
| 36 |
+
# LICENSE is in incl_licenses directory.
|
| 37 |
+
class Snake(nn.Module):
|
| 38 |
+
'''
|
| 39 |
+
Implementation of a sine-based periodic activation function
|
| 40 |
+
Shape:
|
| 41 |
+
- Input: (B, C, T)
|
| 42 |
+
- Output: (B, C, T), same shape as the input
|
| 43 |
+
Parameters:
|
| 44 |
+
- alpha - trainable parameter
|
| 45 |
+
References:
|
| 46 |
+
- This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
|
| 47 |
+
https://arxiv.org/abs/2006.08195
|
| 48 |
+
Examples:
|
| 49 |
+
>>> a1 = Snake(256)
|
| 50 |
+
>>> x = torch.randn(256)
|
| 51 |
+
>>> x = a1(x)
|
| 52 |
+
|
| 53 |
+
Args:
|
| 54 |
+
in_features: shape of the input
|
| 55 |
+
alpha: trainable parameter
|
| 56 |
+
alpha_trainable: whether alpha is trainable
|
| 57 |
+
alpha_logscale: whether to use log scale for alpha
|
| 58 |
+
alpha is initialized to 1 by default, higher values = higher-frequency.
|
| 59 |
+
alpha will be trained along with the rest of your model.
|
| 60 |
+
'''
|
| 61 |
+
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
|
| 62 |
+
super(Snake, self).__init__()
|
| 63 |
+
self.in_features = in_features
|
| 64 |
+
|
| 65 |
+
# initialize alpha
|
| 66 |
+
self.alpha_logscale = alpha_logscale
|
| 67 |
+
if self.alpha_logscale: # log scale alphas initialized to zeros
|
| 68 |
+
self.alpha = nn.Parameter(torch.zeros(in_features) * alpha)
|
| 69 |
+
else: # linear scale alphas initialized to ones
|
| 70 |
+
self.alpha = nn.Parameter(torch.ones(in_features) * alpha)
|
| 71 |
+
|
| 72 |
+
self.alpha.requires_grad = alpha_trainable
|
| 73 |
+
|
| 74 |
+
self.no_div_by_zero = 0.000000001
|
| 75 |
+
|
| 76 |
+
def forward(self, x):
|
| 77 |
+
'''
|
| 78 |
+
Forward pass of the function.
|
| 79 |
+
Applies the function to the input elementwise.
|
| 80 |
+
Snake(x) := x + (1/a) * sin^2(a * x)
|
| 81 |
+
'''
|
| 82 |
+
alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
|
| 83 |
+
if self.alpha_logscale:
|
| 84 |
+
alpha = torch.exp(alpha)
|
| 85 |
+
x = x + (1.0 / (alpha + self.no_div_by_zero)) * torch.pow(torch.sin(x * alpha), 2)
|
| 86 |
+
|
| 87 |
+
return x
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
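A quick numeric check of the Snake activation above (illustrative sketch, not repository code): with a linear-scale alpha of 1.0 the layer computes x + sin^2(x), so zeros map to zeros and ones map to roughly 1.708.

# hedged example: exercising the Snake class defined above
import torch

snake = Snake(in_features=4, alpha=1.0, alpha_logscale=False)
x = torch.zeros(1, 4, 8)                       # (B, C, T)
print(torch.allclose(snake(x), x))             # True: Snake(0) = 0 + sin^2(0)
print(snake(torch.ones(1, 4, 1))[0, 0, 0])     # ~1.708 = 1 + sin(1)^2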
class ResBlock(torch.nn.Module):
|
| 91 |
+
"""Residual block module in HiFiGAN/BigVGAN."""
|
| 92 |
+
def __init__(
|
| 93 |
+
self,
|
| 94 |
+
channels: int = 512,
|
| 95 |
+
kernel_size: int = 3,
|
| 96 |
+
dilations: List[int] = [1, 3, 5], # noqa
|
| 97 |
+
):
|
| 98 |
+
super(ResBlock, self).__init__()
|
| 99 |
+
self.convs1 = nn.ModuleList()
|
| 100 |
+
self.convs2 = nn.ModuleList()
|
| 101 |
+
|
| 102 |
+
for dilation in dilations:
|
| 103 |
+
self.convs1.append(
|
| 104 |
+
weight_norm( # noqa
|
| 105 |
+
Conv1d(
|
| 106 |
+
channels,
|
| 107 |
+
channels,
|
| 108 |
+
kernel_size,
|
| 109 |
+
1,
|
| 110 |
+
dilation=dilation,
|
| 111 |
+
padding=get_padding(kernel_size, dilation)
|
| 112 |
+
)
|
| 113 |
+
)
|
| 114 |
+
)
|
| 115 |
+
self.convs2.append(
|
| 116 |
+
weight_norm( # noqa
|
| 117 |
+
Conv1d(
|
| 118 |
+
channels,
|
| 119 |
+
channels,
|
| 120 |
+
kernel_size,
|
| 121 |
+
1,
|
| 122 |
+
dilation=1,
|
| 123 |
+
padding=get_padding(kernel_size, 1)
|
| 124 |
+
)
|
| 125 |
+
)
|
| 126 |
+
)
|
| 127 |
+
self.convs1.apply(init_weights)
|
| 128 |
+
self.convs2.apply(init_weights)
|
| 129 |
+
self.activations1 = nn.ModuleList([
|
| 130 |
+
Snake(channels, alpha_logscale=False)
|
| 131 |
+
for _ in range(len(self.convs1))
|
| 132 |
+
])
|
| 133 |
+
self.activations2 = nn.ModuleList([
|
| 134 |
+
Snake(channels, alpha_logscale=False)
|
| 135 |
+
for _ in range(len(self.convs2))
|
| 136 |
+
])
|
| 137 |
+
|
| 138 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 139 |
+
for idx in range(len(self.convs1)):
|
| 140 |
+
xt = self.activations1[idx](x)
|
| 141 |
+
xt = self.convs1[idx](xt)
|
| 142 |
+
xt = self.activations2[idx](xt)
|
| 143 |
+
xt = self.convs2[idx](xt)
|
| 144 |
+
x = xt + x
|
| 145 |
+
return x
|
| 146 |
+
|
| 147 |
+
def remove_weight_norm(self):
|
| 148 |
+
for idx in range(len(self.convs1)):
|
| 149 |
+
remove_weight_norm(self.convs1[idx])
|
| 150 |
+
remove_weight_norm(self.convs2[idx])
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
class SineGen(torch.nn.Module):
|
| 154 |
+
""" Definition of sine generator
|
| 155 |
+
SineGen(samp_rate, harmonic_num = 0,
|
| 156 |
+
sine_amp = 0.1, noise_std = 0.003,
|
| 157 |
+
voiced_threshold = 0,
|
| 158 |
+
flag_for_pulse=False)
|
| 159 |
+
samp_rate: sampling rate in Hz
|
| 160 |
+
harmonic_num: number of harmonic overtones (default 0)
|
| 161 |
+
sine_amp: amplitude of sine-waveform (default 0.1)
|
| 162 |
+
noise_std: std of Gaussian noise (default 0.003)
|
| 163 |
+
voiced_threshold: F0 threshold for U/V classification (default 0)
|
| 164 |
+
flag_for_pulse: this SineGen is used inside PulseGen (default False)
|
| 165 |
+
Note: when flag_for_pulse is True, the first time step of a voiced
|
| 166 |
+
segment is always sin(np.pi) or cos(0)
|
| 167 |
+
"""
|
| 168 |
+
|
| 169 |
+
def __init__(self, samp_rate, harmonic_num=0,
|
| 170 |
+
sine_amp=0.1, noise_std=0.003,
|
| 171 |
+
voiced_threshold=0):
|
| 172 |
+
super(SineGen, self).__init__()
|
| 173 |
+
self.sine_amp = sine_amp
|
| 174 |
+
self.noise_std = noise_std
|
| 175 |
+
self.harmonic_num = harmonic_num
|
| 176 |
+
self.sampling_rate = samp_rate
|
| 177 |
+
self.voiced_threshold = voiced_threshold
|
| 178 |
+
|
| 179 |
+
def _f02uv(self, f0):
|
| 180 |
+
# generate uv signal
|
| 181 |
+
uv = (f0 > self.voiced_threshold).type(torch.float32)
|
| 182 |
+
return uv
|
| 183 |
+
|
| 184 |
+
@torch.no_grad()
|
| 185 |
+
def forward(self, f0):
|
| 186 |
+
"""
|
| 187 |
+
:param f0: [B, 1, sample_len], Hz
|
| 188 |
+
:return: [B, 1, sample_len]
|
| 189 |
+
"""
|
| 190 |
+
|
| 191 |
+
F_mat = torch.zeros((f0.size(0), self.harmonic_num + 1, f0.size(-1))).to(f0.device)
|
| 192 |
+
for i in range(self.harmonic_num + 1):
|
| 193 |
+
F_mat[:, i: i + 1, :] = f0 * (i + 1) / self.sampling_rate
|
| 194 |
+
|
| 195 |
+
theta_mat = 2 * np.pi * (torch.cumsum(F_mat, dim=-1) % 1)
|
| 196 |
+
u_dist = Uniform(low=-np.pi, high=np.pi)
|
| 197 |
+
phase_vec = u_dist.sample(sample_shape=(f0.size(0), self.harmonic_num + 1, 1)).to(F_mat.device)
|
| 198 |
+
phase_vec[:, 0, :] = 0
|
| 199 |
+
|
| 200 |
+
# generate sine waveforms
|
| 201 |
+
sine_waves = self.sine_amp * torch.sin(theta_mat + phase_vec)
|
| 202 |
+
|
| 203 |
+
# generate uv signal
|
| 204 |
+
uv = self._f02uv(f0)
|
| 205 |
+
|
| 206 |
+
# noise: for unvoiced should be similar to sine_amp
|
| 207 |
+
# std = self.sine_amp/3 -> max value ~ self.sine_amp
|
| 208 |
+
# for voiced regions, the std is self.noise_std
|
| 209 |
+
noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
|
| 210 |
+
noise = noise_amp * torch.randn_like(sine_waves)
|
| 211 |
+
|
| 212 |
+
# first: set the unvoiced part to 0 by uv
|
| 213 |
+
# then: additive noise
|
| 214 |
+
sine_waves = sine_waves * uv + noise
|
| 215 |
+
return sine_waves, uv, noise
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
class SourceModuleHnNSF(torch.nn.Module):
|
| 219 |
+
""" SourceModule for hn-nsf
|
| 220 |
+
SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
|
| 221 |
+
add_noise_std=0.003, voiced_threshod=0)
|
| 222 |
+
sampling_rate: sampling_rate in Hz
|
| 223 |
+
harmonic_num: number of harmonic above F0 (default: 0)
|
| 224 |
+
sine_amp: amplitude of sine source signal (default: 0.1)
|
| 225 |
+
add_noise_std: std of additive Gaussian noise (default: 0.003)
|
| 226 |
+
note that amplitude of noise in unvoiced is decided
|
| 227 |
+
by sine_amp
|
| 228 |
+
voiced_threshold: threshold to set U/V given F0 (default: 0)
|
| 229 |
+
Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
|
| 230 |
+
F0_sampled (batchsize, length, 1)
|
| 231 |
+
Sine_source (batchsize, length, 1)
|
| 232 |
+
noise_source (batchsize, length 1)
|
| 233 |
+
uv (batchsize, length, 1)
|
| 234 |
+
"""
|
| 235 |
+
|
| 236 |
+
def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
|
| 237 |
+
add_noise_std=0.003, voiced_threshod=0):
|
| 238 |
+
super(SourceModuleHnNSF, self).__init__()
|
| 239 |
+
|
| 240 |
+
self.sine_amp = sine_amp
|
| 241 |
+
self.noise_std = add_noise_std
|
| 242 |
+
|
| 243 |
+
# to produce sine waveforms
|
| 244 |
+
self.l_sin_gen = SineGen(sampling_rate, harmonic_num,
|
| 245 |
+
sine_amp, add_noise_std, voiced_threshod)
|
| 246 |
+
|
| 247 |
+
# to merge source harmonics into a single excitation
|
| 248 |
+
self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
|
| 249 |
+
self.l_tanh = torch.nn.Tanh()
|
| 250 |
+
|
| 251 |
+
def forward(self, x):
|
| 252 |
+
"""
|
| 253 |
+
Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
|
| 254 |
+
F0_sampled (batchsize, length, 1)
|
| 255 |
+
Sine_source (batchsize, length, 1)
|
| 256 |
+
noise_source (batchsize, length 1)
|
| 257 |
+
"""
|
| 258 |
+
# source for harmonic branch
|
| 259 |
+
with torch.no_grad():
|
| 260 |
+
sine_wavs, uv, _ = self.l_sin_gen(x.transpose(1, 2))
|
| 261 |
+
sine_wavs = sine_wavs.transpose(1, 2)
|
| 262 |
+
uv = uv.transpose(1, 2)
|
| 263 |
+
sine_merge = self.l_tanh(self.l_linear(sine_wavs))
|
| 264 |
+
|
| 265 |
+
# source for noise branch, in the same shape as uv
|
| 266 |
+
noise = torch.randn_like(uv) * self.sine_amp / 3
|
| 267 |
+
return sine_merge, noise, uv
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
class SineGen2(torch.nn.Module):
|
| 271 |
+
""" Definition of sine generator
|
| 272 |
+
SineGen(samp_rate, harmonic_num = 0,
|
| 273 |
+
sine_amp = 0.1, noise_std = 0.003,
|
| 274 |
+
voiced_threshold = 0,
|
| 275 |
+
flag_for_pulse=False)
|
| 276 |
+
samp_rate: sampling rate in Hz
|
| 277 |
+
harmonic_num: number of harmonic overtones (default 0)
|
| 278 |
+
sine_amp: amplitude of sine-waveform (default 0.1)
|
| 279 |
+
noise_std: std of Gaussian noise (default 0.003)
|
| 280 |
+
voiced_threshold: F0 threshold for U/V classification (default 0)
|
| 281 |
+
flag_for_pulse: this SineGen is used inside PulseGen (default False)
|
| 282 |
+
Note: when flag_for_pulse is True, the first time step of a voiced
|
| 283 |
+
segment is always sin(np.pi) or cos(0)
|
| 284 |
+
"""
|
| 285 |
+
|
| 286 |
+
def __init__(self, samp_rate, upsample_scale, harmonic_num=0,
|
| 287 |
+
sine_amp=0.1, noise_std=0.003,
|
| 288 |
+
voiced_threshold=0,
|
| 289 |
+
flag_for_pulse=False):
|
| 290 |
+
super(SineGen2, self).__init__()
|
| 291 |
+
self.sine_amp = sine_amp
|
| 292 |
+
self.noise_std = noise_std
|
| 293 |
+
self.harmonic_num = harmonic_num
|
| 294 |
+
self.dim = self.harmonic_num + 1
|
| 295 |
+
self.sampling_rate = samp_rate
|
| 296 |
+
self.voiced_threshold = voiced_threshold
|
| 297 |
+
self.flag_for_pulse = flag_for_pulse
|
| 298 |
+
self.upsample_scale = upsample_scale
|
| 299 |
+
|
| 300 |
+
def _f02uv(self, f0):
|
| 301 |
+
# generate uv signal
|
| 302 |
+
uv = (f0 > self.voiced_threshold).type(torch.float32)
|
| 303 |
+
return uv
|
| 304 |
+
|
| 305 |
+
def _f02sine(self, f0_values):
|
| 306 |
+
""" f0_values: (batchsize, length, dim)
|
| 307 |
+
where dim indicates fundamental tone and overtones
|
| 308 |
+
"""
|
| 309 |
+
# convert to F0 in rad. The integer part n can be ignored
|
| 310 |
+
# because 2 * np.pi * n doesn't affect phase
|
| 311 |
+
rad_values = (f0_values / self.sampling_rate) % 1
|
| 312 |
+
|
| 313 |
+
# initial phase noise (no noise for fundamental component)
|
| 314 |
+
rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], device=f0_values.device)
|
| 315 |
+
rand_ini[:, 0] = 0
|
| 316 |
+
rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
|
| 317 |
+
|
| 318 |
+
# instantaneous phase sine[t] = sin(2*pi \sum_{i=1}^{t} rad)
|
| 319 |
+
if not self.flag_for_pulse:
|
| 320 |
+
rad_values = torch.nn.functional.interpolate(rad_values.transpose(1, 2),
|
| 321 |
+
scale_factor=1 / self.upsample_scale,
|
| 322 |
+
mode="linear").transpose(1, 2)
|
| 323 |
+
|
| 324 |
+
phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
|
| 325 |
+
phase = torch.nn.functional.interpolate(phase.transpose(1, 2) * self.upsample_scale,
|
| 326 |
+
scale_factor=self.upsample_scale, mode="linear").transpose(1, 2)
|
| 327 |
+
sines = torch.sin(phase)
|
| 328 |
+
else:
|
| 329 |
+
# If necessary, make sure that the first time step of every
|
| 330 |
+
# voiced segments is sin(pi) or cos(0)
|
| 331 |
+
# This is used for pulse-train generation
|
| 332 |
+
|
| 333 |
+
# identify the last time step in unvoiced segments
|
| 334 |
+
uv = self._f02uv(f0_values)
|
| 335 |
+
uv_1 = torch.roll(uv, shifts=-1, dims=1)
|
| 336 |
+
uv_1[:, -1, :] = 1
|
| 337 |
+
u_loc = (uv < 1) * (uv_1 > 0)
|
| 338 |
+
|
| 339 |
+
# get the instantaneous phase
|
| 340 |
+
tmp_cumsum = torch.cumsum(rad_values, dim=1)
|
| 341 |
+
# different batch needs to be processed differently
|
| 342 |
+
for idx in range(f0_values.shape[0]):
|
| 343 |
+
temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :]
|
| 344 |
+
temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :]
|
| 345 |
+
# stores the accumulation of i.phase within
|
| 346 |
+
# each voiced segments
|
| 347 |
+
tmp_cumsum[idx, :, :] = 0
|
| 348 |
+
tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum
|
| 349 |
+
|
| 350 |
+
# rad_values - tmp_cumsum: remove the accumulation of i.phase
|
| 351 |
+
# within the previous voiced segment.
|
| 352 |
+
i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1)
|
| 353 |
+
|
| 354 |
+
# get the sines
|
| 355 |
+
sines = torch.cos(i_phase * 2 * np.pi)
|
| 356 |
+
return sines
|
| 357 |
+
|
| 358 |
+
def forward(self, f0):
|
| 359 |
+
""" sine_tensor, uv = forward(f0)
|
| 360 |
+
input F0: tensor(batchsize=1, length, dim=1)
|
| 361 |
+
f0 for unvoiced steps should be 0
|
| 362 |
+
output sine_tensor: tensor(batchsize=1, length, dim)
|
| 363 |
+
output uv: tensor(batchsize=1, length, 1)
|
| 364 |
+
"""
|
| 365 |
+
# fundamental component
|
| 366 |
+
fn = torch.multiply(f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device))
|
| 367 |
+
|
| 368 |
+
# generate sine waveforms
|
| 369 |
+
sine_waves = self._f02sine(fn) * self.sine_amp
|
| 370 |
+
|
| 371 |
+
# generate uv signal
|
| 372 |
+
uv = self._f02uv(f0)
|
| 373 |
+
|
| 374 |
+
# noise: for unvoiced should be similar to sine_amp
|
| 375 |
+
# std = self.sine_amp/3 -> max value ~ self.sine_amp
|
| 376 |
+
# for voiced regions, the std is self.noise_std
|
| 377 |
+
noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
|
| 378 |
+
noise = noise_amp * torch.randn_like(sine_waves)
|
| 379 |
+
|
| 380 |
+
# first: set the unvoiced part to 0 by uv
|
| 381 |
+
# then: additive noise
|
| 382 |
+
sine_waves = sine_waves * uv + noise
|
| 383 |
+
return sine_waves, uv, noise
|
| 384 |
+
|
| 385 |
+
|
| 386 |
+
class SourceModuleHnNSF2(torch.nn.Module):
|
| 387 |
+
""" SourceModule for hn-nsf
|
| 388 |
+
SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
|
| 389 |
+
add_noise_std=0.003, voiced_threshod=0)
|
| 390 |
+
sampling_rate: sampling_rate in Hz
|
| 391 |
+
harmonic_num: number of harmonic above F0 (default: 0)
|
| 392 |
+
sine_amp: amplitude of sine source signal (default: 0.1)
|
| 393 |
+
add_noise_std: std of additive Gaussian noise (default: 0.003)
|
| 394 |
+
note that amplitude of noise in unvoiced is decided
|
| 395 |
+
by sine_amp
|
| 396 |
+
voiced_threshold: threshold to set U/V given F0 (default: 0)
|
| 397 |
+
Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
|
| 398 |
+
F0_sampled (batchsize, length, 1)
|
| 399 |
+
Sine_source (batchsize, length, 1)
|
| 400 |
+
noise_source (batchsize, length 1)
|
| 401 |
+
uv (batchsize, length, 1)
|
| 402 |
+
"""
|
| 403 |
+
|
| 404 |
+
def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
|
| 405 |
+
add_noise_std=0.003, voiced_threshod=0):
|
| 406 |
+
super(SourceModuleHnNSF2, self).__init__()
|
| 407 |
+
|
| 408 |
+
self.sine_amp = sine_amp
|
| 409 |
+
self.noise_std = add_noise_std
|
| 410 |
+
|
| 411 |
+
# to produce sine waveforms
|
| 412 |
+
self.l_sin_gen = SineGen2(sampling_rate, upsample_scale, harmonic_num,
|
| 413 |
+
sine_amp, add_noise_std, voiced_threshod)
|
| 414 |
+
|
| 415 |
+
# to merge source harmonics into a single excitation
|
| 416 |
+
self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
|
| 417 |
+
self.l_tanh = torch.nn.Tanh()
|
| 418 |
+
|
| 419 |
+
def forward(self, x):
|
| 420 |
+
"""
|
| 421 |
+
Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
|
| 422 |
+
F0_sampled (batchsize, length, 1)
|
| 423 |
+
Sine_source (batchsize, length, 1)
|
| 424 |
+
noise_source (batchsize, length 1)
|
| 425 |
+
"""
|
| 426 |
+
# source for harmonic branch
|
| 427 |
+
with torch.no_grad():
|
| 428 |
+
sine_wavs, uv, _ = self.l_sin_gen(x)
|
| 429 |
+
sine_merge = self.l_tanh(self.l_linear(sine_wavs))
|
| 430 |
+
|
| 431 |
+
# source for noise branch, in the same shape as uv
|
| 432 |
+
noise = torch.randn_like(uv) * self.sine_amp / 3
|
| 433 |
+
return sine_merge, noise, uv
|
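The two SineGen variants above both build the harmonic excitation by integrating the F0 contour into a running phase (a cumulative sum of f/sr), stacking the fundamental with its overtones, and gating voiced/unvoiced regions with noise. A condensed, self-contained sketch of that idea with toy values (illustrative only, not repository code):

import torch

sr, harmonic_num, sine_amp, noise_std = 24000, 2, 0.1, 0.003
f0 = torch.full((1, sr // 10, 1), 220.0)             # 100 ms of a 220 Hz contour
f0[:, : sr // 100] = 0.0                              # first 10 ms unvoiced

k = torch.arange(1, harmonic_num + 2).view(1, 1, -1)  # fundamental + overtones
rad = (f0 * k / sr) % 1.0                             # per-sample phase increment
phase = 2 * torch.pi * torch.cumsum(rad, dim=1)
sines = sine_amp * torch.sin(phase)                   # (1, T, harmonic_num + 1)

uv = (f0 > 10.0).float()                              # voiced/unvoiced mask
noise = (uv * noise_std + (1 - uv) * sine_amp / 3) * torch.randn_like(sines)
source = sines * uv + noise                           # harmonic source signal
print(source.shape)                                   # torch.Size([1, 2400, 3])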
soulxpodcast/models/modules/sampler.py
ADDED
|
@@ -0,0 +1,221 @@
import os

from typing import Any, Callable, Optional, Union
import torch
from torch import nn
from transformers.generation.logits_process import (
    LogitsProcessorList
)
from transformers.generation.stopping_criteria import (
    StoppingCriteriaList
)
from transformers.generation.configuration_utils import (
    GenerationConfig
)
from transformers.generation.streamers import BaseStreamer
from transformers.generation.utils import (
    GenerateNonBeamOutput,
    GenerateEncoderDecoderOutput,
    GenerateDecoderOnlyOutput,
)
from transformers import StoppingCriteria


def _ras_sample_hf_engine(
    self,
    input_ids: torch.LongTensor,
    logits_processor: LogitsProcessorList,
    stopping_criteria: StoppingCriteriaList,
    generation_config: GenerationConfig,
    synced_gpus: bool = False,
    streamer: Optional["BaseStreamer"] = None,
    use_ras=False,
    win_size=25,
    tau_r=0.2,
    **model_kwargs,
) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
    r"""
    Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and
    can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.

    Parameters:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            The sequence used as a prompt for the generation.
        logits_processor (`LogitsProcessorList`):
            An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
            used to modify the prediction scores of the language modeling head applied at each generation step.
        stopping_criteria (`StoppingCriteriaList`):
            An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
            used to tell if the generation loop should stop.
        generation_config ([`~generation.GenerationConfig`]):
            The generation configuration to be used as parametrization of the decoding method.
        synced_gpus (`bool`):
            Whether to continue running the while loop until max_length (needed to avoid deadlocking with
            `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
        streamer (`BaseStreamer`, *optional*):
            Streamer object that will be used to stream the generated sequences. Generated tokens are passed
            through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
        model_kwargs:
            Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
            an encoder-decoder model the kwargs should include `encoder_outputs`.

    Return:
        [`~generation.GenerateDecoderOnlyOutput`], [`~generation.GenerateEncoderDecoderOutput`] or `torch.LongTensor`:
            A `torch.LongTensor` containing the generated tokens (default behaviour) or a
            [`~generation.GenerateDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
            `return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if
            `model.config.is_encoder_decoder=True`.
    """
    # init values
    pad_token_id = generation_config._pad_token_tensor
    output_attentions = generation_config.output_attentions
    output_hidden_states = generation_config.output_hidden_states
    output_scores = generation_config.output_scores
    output_logits = generation_config.output_logits
    return_dict_in_generate = generation_config.return_dict_in_generate
    has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
    do_sample = generation_config.do_sample

    # init attention / hidden states / scores tuples
    scores = () if (return_dict_in_generate and output_scores) else None
    raw_logits = () if (return_dict_in_generate and output_logits) else None
    decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
    cross_attentions = () if (return_dict_in_generate and output_attentions) else None
    decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

    # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
    if return_dict_in_generate and self.config.is_encoder_decoder:
        encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
        encoder_hidden_states = (
            model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
        )

    # keep track of which sequences are already finished
    batch_size, cur_len = input_ids.shape[:2]
    this_peer_finished = False
    unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
    model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)

    model_forward = self.__call__
    compile_forward = self._valid_auto_compile_criteria(model_kwargs, generation_config)
    if compile_forward:
        os.environ["TOKENIZERS_PARALLELISM"] = "0"
        model_forward = self.get_compiled_call(generation_config.compile_config)

    if generation_config.prefill_chunk_size is not None:
        model_kwargs = self._prefill_chunking(input_ids, generation_config, **model_kwargs)
        is_prefill = False
    else:
        is_prefill = True

    while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
        # prepare model inputs
        model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

        # prepare variable output controls (note: some models won't accept all output controls)
        model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
        model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})

        if is_prefill:
            outputs = self(**model_inputs, return_dict=True)
            is_prefill = False
        else:
            outputs = model_forward(**model_inputs, return_dict=True)

        # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
        model_kwargs = self._update_model_kwargs_for_generation(
            outputs,
            model_kwargs,
            is_encoder_decoder=self.config.is_encoder_decoder,
        )
        if synced_gpus and this_peer_finished:
            continue

        # Copy is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
        # (the clone itself is always small)
        next_token_logits = outputs.logits[:, -1, :].to(copy=True, dtype=torch.float32, device=input_ids.device)

        # pre-process distribution
        next_token_scores = logits_processor(input_ids, next_token_logits)

        # Repetition Aware Sampling in VALL-E 2
        if use_ras:
            probs_candidate = nn.functional.softmax(next_token_scores, dim=-1)
            next_tokens_candidate = torch.multinomial(probs_candidate, num_samples=1).squeeze(1)
            rep_num = (input_ids[:, -win_size:] == next_tokens_candidate).sum().item() + 1
            if rep_num >= win_size * tau_r:
                next_token_scores = next_token_logits

        # Store scores, attentions and hidden_states when required
        if return_dict_in_generate:
            if output_scores:
                scores += (next_token_scores,)
            if output_logits:
                raw_logits += (next_token_logits,)
            if output_attentions:
                decoder_attentions += (
                    (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                )
                if self.config.is_encoder_decoder:
                    cross_attentions += (outputs.cross_attentions,)

            if output_hidden_states:
                decoder_hidden_states += (
                    (outputs.decoder_hidden_states,)
                    if self.config.is_encoder_decoder
                    else (outputs.hidden_states,)
                )

        # token selection
        if do_sample:
            probs = nn.functional.softmax(next_token_scores, dim=-1)
            # TODO (joao): this OP throws "skipping cudagraphs due to ['incompatible ops']", find solution
            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
        else:
            next_tokens = torch.argmax(next_token_scores, dim=-1)

        # finished sentences should have their next token be a padding token
        if has_eos_stopping_criteria:
            next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

        # update generated ids, model inputs, and length for next step
        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        if streamer is not None:
            streamer.put(next_tokens.cpu())

        unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
        this_peer_finished = unfinished_sequences.max() == 0
        cur_len += 1

        # This is needed to properly delete outputs.logits which may be very large for first iteration
        # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
        del outputs

    if streamer is not None:
        streamer.end()

    if return_dict_in_generate:
        if self.config.is_encoder_decoder:
            return GenerateEncoderDecoderOutput(
                sequences=input_ids,
                scores=scores,
                logits=raw_logits,
                encoder_attentions=encoder_attentions,
                encoder_hidden_states=encoder_hidden_states,
                decoder_attentions=decoder_attentions,
                cross_attentions=cross_attentions,
                decoder_hidden_states=decoder_hidden_states,
                past_key_values=model_kwargs.get("past_key_values"),
            )
        else:
            return GenerateDecoderOnlyOutput(
                sequences=input_ids,
                scores=scores,
                logits=raw_logits,
                attentions=decoder_attentions,
                hidden_states=decoder_hidden_states,
                past_key_values=model_kwargs.get("past_key_values"),
            )
    else:
        return input_ids
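The helper above follows the stock multinomial sampling loop from transformers and adds one step: Repetition Aware Sampling (RAS) as described for VALL-E 2. Each iteration first draws a candidate token from the processed distribution; if that candidate already saturates the last `win_size` generated tokens (at least `win_size * tau_r` counted occurrences, including the candidate itself), the step falls back to the raw, unprocessed logits before the final token is drawn, which helps break repetition loops in speech-token decoding. The standalone sketch below isolates just that decision rule; the function name `pick_next_token` and the toy tensors are illustrative only and not part of the repository.

import torch
from torch import nn

def pick_next_token(input_ids, next_token_logits, next_token_scores,
                    win_size=25, tau_r=0.2):
    # Draw a candidate from the processed (e.g. top-k / top-p filtered) distribution.
    probs = nn.functional.softmax(next_token_scores, dim=-1)
    candidate = torch.multinomial(probs, num_samples=1).squeeze(1)
    # Count how often the candidate already appears in the last `win_size` tokens
    # (the "+ 1" counts the candidate draw itself, as in the code above).
    rep_num = (input_ids[:, -win_size:] == candidate).sum().item() + 1
    if rep_num >= win_size * tau_r:
        # Too repetitive: resample from the raw, unprocessed logits instead.
        raw_probs = nn.functional.softmax(next_token_logits, dim=-1)
        return torch.multinomial(raw_probs, num_samples=1).squeeze(1)
    return candidate

# Toy example: token 7 dominates the recent window and the processed scores,
# so the fallback path is very likely taken.
history = torch.tensor([[7] * 30])
logits = torch.randn(1, 100)
scores = logits.clone()
scores[0, 7] += 5.0
print(pick_next_token(history, logits, scores))

With the defaults (`win_size=25`, `tau_r=0.2`), four prior occurrences of the candidate in the last 25 tokens are enough to trigger the fallback, since `rep_num` also counts the newly drawn candidate.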
soulxpodcast/models/soulxpodcast.py
ADDED
@@ -0,0 +1,192 @@
import time
from datetime import datetime
from itertools import chain
from tqdm import tqdm
from copy import deepcopy

import numpy as np
import s3tokenizer
import torch

from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache
from soulxpodcast.config import Config, SamplingParams, AutoPretrainedConfig
from soulxpodcast.engine.llm_engine import (
    HFLLMEngine, VLLMEngine
)
from soulxpodcast.models.modules.flow import CausalMaskedDiffWithXvec
from soulxpodcast.models.modules.hifigan import HiFTGenerator


class SoulXPodcast(torch.nn.Module):
    def __init__(self, config: Config = None):
        super().__init__()
        self.config = Config() if config is None else config

        self.audio_tokenizer = s3tokenizer.load_model("speech_tokenizer_v2_25hz").cuda().eval()
        if self.config.llm_engine == "hf":
            self.llm = HFLLMEngine(**self.config.__dict__)
        elif self.config.llm_engine == "vllm":
            self.llm = VLLMEngine(**self.config.__dict__)
        else:
            raise NotImplementedError

        self.use_tqdm = True

        self.flow = CausalMaskedDiffWithXvec()
        if self.config.hf_config.fp16_flow:
            timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
            tqdm.write(f"[{timestamp}] - [INFO] - Casting flow to fp16")
            self.flow.half()
        self.flow.load_state_dict(torch.load(f"{self.config.model}/flow.pt", map_location="cpu", weights_only=True), strict=True)
        self.flow.cuda().eval()

        self.hift = HiFTGenerator()
        hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(f"{self.config.model}/hift.pt", map_location="cpu", weights_only=True).items()}
        self.hift.load_state_dict(hift_state_dict, strict=True)
        self.hift.cuda().eval()

    @torch.inference_mode()
    def forward_longform(
        self, prompt_mels_for_llm,
        prompt_mels_lens_for_llm: torch.Tensor,
        prompt_text_tokens_for_llm: list[list[int]],
        text_tokens_for_llm: list[list[int]],
        prompt_mels_for_flow_ori,
        spk_emb_for_flow: torch.Tensor,
        sampling_params: SamplingParams | list[SamplingParams],
        spk_ids: list[list[int]],
        use_prompt_cot: bool = False,
        prompt_cot_text_tokens_for_llm: list[list[int]] = None,
        prompt_cot_prefix: list[list[int]] = None,
        **kwargs,  # for compatibility
    ):
        prompt_size, turn_size = len(prompt_mels_for_llm), len(text_tokens_for_llm)

        # Audio tokenization
        prompt_speech_tokens_ori, prompt_speech_tokens_lens_ori = self.audio_tokenizer.quantize(
            prompt_mels_for_llm.cuda(), prompt_mels_lens_for_llm.cuda()
        )

        # Align speech tokens with speech features to reduce the noise ratio
        # during the generation process.
        prompt_speech_tokens = []
        prompt_mels_for_flow, prompt_mels_lens_for_flow = [], []

        for prompt_index in range(prompt_size):
            prompt_speech_token_len = prompt_speech_tokens_lens_ori[prompt_index].item()
            prompt_speech_token = prompt_speech_tokens_ori[prompt_index, :prompt_speech_token_len]
            prompt_mel = prompt_mels_for_flow_ori[prompt_index]
            prompt_mel_len = prompt_mel.shape[0]
            if prompt_speech_token_len * 2 > prompt_mel_len:
                prompt_speech_token = prompt_speech_token[:int(prompt_mel_len / 2)]
                prompt_mel_len = torch.tensor([prompt_mel_len]).cuda()
            else:
                prompt_mel = prompt_mel.detach().clone()[:prompt_speech_token_len * 2].cuda()
                prompt_mel_len = torch.tensor([prompt_speech_token_len * 2]).cuda()
            prompt_speech_tokens.append(prompt_speech_token)
            prompt_mels_for_flow.append(prompt_mel)
            prompt_mels_lens_for_flow.append(prompt_mel_len)

        # Prepare LLM inputs
        prompt_inputs = []
        history_inputs = []

        # for i in range(prompt_size):
        #     prompt_mels = prompt_mels_for_flow[i][None]
        #     prompt_mels_lens = prompt_mels_lens_for_flow[i]
        #     spk_emb = spk_emb_for_flow[i:i+1]
        #
        #     # Flow generation
        #     with torch.amp.autocast("cuda", dtype=torch.float16 if self.config.hf_config.fp16_flow else torch.float32):
        #         flow_input = torch.concat([prompt_speech_tokens[i].detach().clone(), prompt_speech_tokens[1].detach().clone()], axis=0)[None]
        #         flow_inputs_len = torch.tensor(flow_input.shape[1])[None]
        #         generated_mels, generated_mels_lens = self.flow(
        #             flow_input.cuda(), flow_inputs_len.cuda(),
        #             prompt_mels, prompt_mels_lens, spk_emb.cuda(),
        #             streaming=False, finalize=True
        #         )
        #
        #     # HiFi-GAN generation
        #     mel = generated_mels[:, :, prompt_mels_lens[0].item():generated_mels_lens[0].item()]
        #     wav, _ = self.hift(speech_feat=mel)
        #     import soundfile as sf
        #     sf.write(f"{str(i).zfill(2)}.wav", wav.cpu().squeeze(0).numpy(), 24000)

        for i in range(prompt_size):
            speech_tokens_i = [token + self.config.hf_config.speech_token_offset for token in prompt_speech_tokens[i].tolist()]
            speech_tokens_i += [self.config.hf_config.eos_token_id]
            if use_prompt_cot and len(prompt_cot_text_tokens_for_llm[i]) > 0:
                prompt_cot_input = prompt_text_tokens_for_llm[i] + speech_tokens_i + prompt_cot_text_tokens_for_llm[i]
                if i > 0:
                    prompt_cot_input = prompt_cot_prefix[0] + prompt_cot_input
                cot_input = self.llm.generate(prompt_cot_input, sampling_params, past_key_values=None)['token_ids']
                prompt_inputs.append(prompt_cot_prefix[i+1] + prompt_cot_text_tokens_for_llm[i] + cot_input)
                history_inputs.append(prompt_cot_prefix[i+1] + prompt_cot_text_tokens_for_llm[i] + cot_input)
            else:
                prompt_inputs.append(prompt_text_tokens_for_llm[i] + speech_tokens_i)
                history_inputs.append(prompt_text_tokens_for_llm[i] + speech_tokens_i)

        generated_wavs, results_dict = [], {}

        # LLM generation
        inputs = list(chain.from_iterable(prompt_inputs))
        cache_config = AutoPretrainedConfig().from_dataclass(self.llm.config.hf_config)
        past_key_values = DynamicCache(config=cache_config)
        valid_turn_size = prompt_size
        for i in range(turn_size):

            # Reset the KV cache once the dialogue history exceeds the configured turn/token budget.
            if valid_turn_size > self.config.max_turn_size or len(inputs) > self.config.turn_tokens_threshold:
                assert self.config.max_turn_size >= self.config.prompt_context + self.config.history_context, "Invalid long-history size setting"
                prompt_text_bound = max(self.config.prompt_context, len(history_inputs) - self.config.history_text_context - self.config.history_context)
                inputs = list(chain.from_iterable(
                    history_inputs[:self.config.prompt_context] + \
                    history_inputs[prompt_text_bound:-self.config.history_context] + \
                    prompt_inputs[-self.config.history_context:]
                ))
                valid_turn_size = self.config.prompt_context + len(history_inputs) - prompt_text_bound
                past_key_values = DynamicCache(config=cache_config)
            valid_turn_size += 1

            inputs.extend(text_tokens_for_llm[i])
            start_time = time.time()
            llm_outputs = self.llm.generate(inputs, sampling_params, past_key_values=past_key_values)

            inputs.extend(llm_outputs['token_ids'])
            prompt_inputs.append(text_tokens_for_llm[i] + llm_outputs['token_ids'])
            history_inputs.append(text_tokens_for_llm[i][:-1])  # remove the <|audio_start|>

            # Prepare Flow inputs
            turn_spk = spk_ids[i]
            generated_speech_tokens = [token - self.config.hf_config.speech_token_offset for token in llm_outputs['token_ids'][:-1]]  # ignore last eos
            prompt_speech_token = prompt_speech_tokens[turn_spk].tolist()
            flow_input = torch.tensor([prompt_speech_token + generated_speech_tokens])
            flow_inputs_len = torch.tensor([len(prompt_speech_token) + len(generated_speech_tokens)])

            # Flow generation and HiFi-GAN generation
            start_idx = spk_ids[i]
            prompt_mels = prompt_mels_for_flow[start_idx][None]
            prompt_mels_lens = prompt_mels_lens_for_flow[start_idx][None]
            spk_emb = spk_emb_for_flow[start_idx:start_idx+1]

            # Flow generation
            with torch.amp.autocast("cuda", dtype=torch.float16 if self.config.hf_config.fp16_flow else torch.float32):
                generated_mels, generated_mels_lens = self.flow(
                    flow_input.cuda(), flow_inputs_len.cuda(),
                    prompt_mels, prompt_mels_lens, spk_emb.cuda(),
                    streaming=False, finalize=True
                )

            # HiFi-GAN generation
            mel = generated_mels[:, :, prompt_mels_lens[0].item():generated_mels_lens[0].item()]
            try:
                wav, _ = self.hift(speech_feat=mel)
            except Exception as e:
                # Surface vocoder failures instead of dropping into an interactive debugger.
                print(e)
                raise
            generated_wavs.append(wav)

        # Save the generated wavs
        results_dict['generated_wavs'] = generated_wavs
        return results_dict
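For orientation, here is a rough sketch of how this module might be driven end to end. It assumes a `Config` that accepts the checkpoint directory and engine name as keyword arguments (matching the `self.config.model` and `self.config.llm_engine` attributes used above), a default-constructible `SamplingParams`, and a hypothetical `load_prompt_features` helper standing in for the real frontend that produces prompt mels, token ids, and speaker embeddings; these placeholders are not guaranteed to match the actual CLI. A CUDA device is required, since the constructor moves every sub-model to the GPU.

import soundfile as sf
import torch

from soulxpodcast.config import Config, SamplingParams
from soulxpodcast.models.soulxpodcast import SoulXPodcast

# Assumed: Config takes the checkpoint directory and engine name as keyword arguments.
config = Config(model="pretrained_models/SoulX-Podcast", llm_engine="hf")
model = SoulXPodcast(config)

# Placeholder: the real pipeline builds these features (prompt mel spectrograms,
# tokenized prompt/turn text, speaker embeddings, per-turn speaker ids) elsewhere.
batch = load_prompt_features(...)  # hypothetical helper, not part of this repository

results = model.forward_longform(
    prompt_mels_for_llm=batch["prompt_mels_for_llm"],
    prompt_mels_lens_for_llm=batch["prompt_mels_lens_for_llm"],
    prompt_text_tokens_for_llm=batch["prompt_text_tokens_for_llm"],
    text_tokens_for_llm=batch["text_tokens_for_llm"],
    prompt_mels_for_flow_ori=batch["prompt_mels_for_flow_ori"],
    spk_emb_for_flow=batch["spk_emb_for_flow"],
    sampling_params=SamplingParams(),
    spk_ids=batch["spk_ids"],
)

# One waveform per generated turn; concatenate and write at the vocoder's 24 kHz rate.
wav = torch.cat(results["generated_wavs"], dim=1).cpu().squeeze(0).numpy()
sf.write("podcast.wav", wav, 24000)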
soulxpodcast/utils/__init__.py
ADDED
File without changes