Update README.md
Browse files
README.md
CHANGED
@@ -1,4 +1,169 @@
|
|
1 |
---
|
2 |
base_model:
|
3 |
- m-a-p/YuE-s1-7B-anneal-en-cot
|
4 |
-
---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
---
base_model:
- m-a-p/YuE-s1-7B-anneal-en-cot
---

# Sample Inference Script

```py
import random
import re
import sys
from argparse import ArgumentParser
from pathlib import Path

# Make the vendored xcodec_mini_infer package importable when the script is
# run from the repository root.
sys.path.append("xcodec_mini_infer")

import torch
import torchaudio
import yaml
from exllamav2 import (
    ExLlamaV2,
    ExLlamaV2Cache,
    ExLlamaV2Config,
    ExLlamaV2Tokenizer,
    Timer,
)
from exllamav2.generator import (
    ExLlamaV2DynamicGenerator,
    ExLlamaV2DynamicJob,
    ExLlamaV2Sampler,
)
from rich import print

from xcodec_mini_infer.models.soundstream_hubert_new import SoundStream

parser = ArgumentParser()
parser.add_argument("-s1", "--stage-1", required=True)
parser.add_argument("-g", "--genre", default="genre.txt")
parser.add_argument("-l", "--lyrics", default="lyrics.txt")
parser.add_argument("-d", "--debug", action="store_true")
parser.add_argument("-s", "--seed", type=int, default=None)
parser.add_argument("--sample_rate", type=int, default=16000)
parser.add_argument("--repetition_penalty", type=float, default=1.2)
parser.add_argument("--temperature", type=float, default=1.0)
parser.add_argument("--top_p", type=float, default=0.93)
args = parser.parse_args()

# --- Stage 1: load the language model ------------------------------------
with Timer() as timer:
    config = ExLlamaV2Config(args.stage_1)
    model = ExLlamaV2(config, lazy_load=True)
    cache = ExLlamaV2Cache(model, lazy=True)
    model.load_autosplit(cache)

    tokenizer = ExLlamaV2Tokenizer(config, lazy_init=True)
    generator = ExLlamaV2DynamicGenerator(model, cache, tokenizer)
    generator.warmup()

print(f"Loaded stage 1 model in {timer.interval:.2f} seconds.")

# --genre / --lyrics accept either a file path or the literal text itself.
genre = Path(args.genre)
genre = genre.read_text(encoding="utf-8") if genre.is_file() else args.genre
genre = genre.strip()

lyrics = Path(args.lyrics)
lyrics = lyrics.read_text(encoding="utf-8") if lyrics.is_file() else args.lyrics
lyrics = lyrics.strip()

# Split the lyrics into "[section]\ntext" segments (e.g. [verse], [chorus]).
lyrics = re.findall(r"\[(\w+)\](.*?)\n(?=\[|\Z)", lyrics, re.DOTALL)
lyrics = [f"[{tag}]\n{body.strip()}\n\n" for tag, body in lyrics]
lyrics_joined = "\n".join(lyrics)

gen_settings = ExLlamaV2Sampler.Settings()
# Constrain sampling to token 32002 plus ids 45334-46357 — presumably <EOA>
# and the xcodec audio-token range; TODO confirm against the model's vocab.
gen_settings.allow_tokens(tokenizer, [32002] + list(range(45334, 46358)))
gen_settings.temperature = args.temperature
gen_settings.token_repetition_penalty = args.repetition_penalty
gen_settings.top_p = args.top_p

# BUG FIX: compare against None so an explicit `--seed 0` is honoured
# (the original truthiness test discarded a zero seed).
seed = args.seed if args.seed is not None else random.randint(0, 2**64 - 1)
stop_conditions = ["<EOA>"]

output_joined = ""  # transcript of all finished segments, fed back as context
output = []         # every streamed token string, in generation order

with Timer() as timer:
    for segment in lyrics:
        current = []  # token strings belonging to the segment in progress

        # Renamed from `input` to avoid shadowing the builtin. The prompt is:
        # instruction, genre line, full lyrics for context, previously
        # generated segments, then the segment to synthesise next.
        prompt = (
            "Generate music from the given lyrics segment by segment.\n"
            f"[Genre] {genre}\n"
            f"{lyrics_joined}{output_joined}[start_of_segment]{segment}<SOA><xcodec>"
        )

        input_ids = tokenizer.encode(prompt, encode_special_tokens=True)
        input_len = input_ids.shape[-1]
        max_new_tokens = config.max_seq_len - input_len

        print(
            f"Using {input_len} tokens of {config.max_seq_len} tokens "
            f"with {max_new_tokens} tokens left."
        )

        job = ExLlamaV2DynamicJob(
            input_ids=input_ids,
            max_new_tokens=max_new_tokens,
            gen_settings=gen_settings,
            seed=seed,
            stop_conditions=stop_conditions,
            decode_special_tokens=True,
        )

        generator.enqueue(job)

        with Timer() as inner:
            while generator.num_remaining_jobs():
                for result in generator.iterate():
                    if result.get("stage") == "streaming":
                        text = result.get("text")

                        if text:
                            current.append(text)
                            output.append(text)

                            if args.debug:
                                print(text, end="", flush=True)

                    # On end-of-stream, fold the finished segment into the
                    # transcript used to prompt the next segment.
                    if result.get("eos") and current:
                        current_joined = "".join(current)
                        output_joined += (
                            f"[start_of_segment]{segment}<SOA><xcodec>"
                            f"{current_joined}<EOA>[end_of_segment]"
                        )

        if args.debug:
            print()

        print(f"Generated {len(current)} tokens in {inner.interval:.2f} seconds.")

print(f"Finished in {timer.interval:.2f} seconds with seed {seed}.")

# --- Stage 2: load the xcodec vocoder ------------------------------------
with Timer() as timer:
    codec_config = Path("xcodec_mini_infer/final_ckpt/config.yaml")
    codec_config = yaml.safe_load(codec_config.read_bytes())
    codec = SoundStream(**codec_config["generator"]["config"])
    # NOTE(review): torch.load unpickles the checkpoint — only run this on a
    # checkpoint you trust. map_location="cpu" keeps the load off the GPU
    # until the explicit .cuda() below.
    state_dict = torch.load(
        "xcodec_mini_infer/final_ckpt/ckpt_00360000.pth", map_location="cpu"
    )
    codec.load_state_dict(state_dict["codec_model"])
    codec = codec.eval().cuda()

print(f"Loaded codec in {timer.interval:.2f} seconds.")


def _decode_and_save(codes, path):
    """Decode one stream of xcodec ids to a waveform and write it to *path*."""
    frames = torch.tensor([[codes]]).cuda()
    frames = frames.permute(1, 0, 2)  # layout expected by codec.decode
    audio = codec.decode(frames).squeeze(0).cpu()
    torchaudio.save(path, audio, args.sample_rate)


with Timer() as timer, torch.inference_mode():
    # Keep only audio tokens of the form <xcodec/0/NNN>, extracting NNN via
    # the capture group instead of the original magic slice o[10:-1].
    pattern = re.compile(r"<xcodec/0/(\d+)>")
    matches = (pattern.match(o) for o in output)
    output_ids = [int(m.group(1)) for m in matches if m]

    # Generated ids interleave the two stems: even indices -> vocal.wav,
    # odd indices -> inst.wav.
    _decode_and_save(output_ids[::2], "vocal.wav")
    _decode_and_save(output_ids[1::2], "inst.wav")

print(f"Decoded audio in {timer.interval:.2f} seconds.")
```
|