Annuvin committed on
Commit
87043db
·
verified ·
1 Parent(s): 8c28a4c

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +166 -1
README.md CHANGED
@@ -1,4 +1,169 @@
1
  ---
2
  base_model:
3
  - m-a-p/YuE-s1-7B-anneal-en-cot
4
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
  base_model:
3
  - m-a-p/YuE-s1-7B-anneal-en-cot
4
+ ---
5
+
6
+ # Sample Inference Script
7
+ ```py
8
"""Sample inference script for the YuE stage-1 model (exllamav2 quant).

Generates xcodec audio tokens from a genre prompt and tagged lyrics,
segment by segment, then decodes the interleaved token stream into
``vocal.wav`` and ``inst.wav`` with the xcodec mini codec.
"""
import random
import re
import sys
from argparse import ArgumentParser
from pathlib import Path

# The xcodec package resolves its internal imports relative to its own root.
sys.path.append("xcodec_mini_infer")

import torch
import torchaudio
import yaml
from exllamav2 import (
    ExLlamaV2,
    ExLlamaV2Cache,
    ExLlamaV2Config,
    ExLlamaV2Tokenizer,
    Timer,
)
from exllamav2.generator import (
    ExLlamaV2DynamicGenerator,
    ExLlamaV2DynamicJob,
    ExLlamaV2Sampler,
)
from rich import print

from xcodec_mini_infer.models.soundstream_hubert_new import SoundStream

parser = ArgumentParser()
parser.add_argument("-s1", "--stage-1", required=True)
parser.add_argument("-g", "--genre", default="genre.txt")
parser.add_argument("-l", "--lyrics", default="lyrics.txt")
parser.add_argument("-d", "--debug", action="store_true")
parser.add_argument("-s", "--seed", type=int, default=None)
parser.add_argument("--sample_rate", type=int, default=16000)
parser.add_argument("--repetition_penalty", type=float, default=1.2)
parser.add_argument("--temperature", type=float, default=1.0)
parser.add_argument("--top_p", type=float, default=0.93)
args = parser.parse_args()

with Timer() as timer:
    config = ExLlamaV2Config(args.stage_1)
    model = ExLlamaV2(config, lazy_load=True)
    cache = ExLlamaV2Cache(model, lazy=True)
    model.load_autosplit(cache)

    tokenizer = ExLlamaV2Tokenizer(config, lazy_init=True)
    generator = ExLlamaV2DynamicGenerator(model, cache, tokenizer)
    generator.warmup()

print(f"Loaded stage 1 model in {timer.interval:.2f} seconds.")

# --genre/--lyrics accept either a file path or the literal text itself.
genre = Path(args.genre)
genre = genre.read_text(encoding="utf-8") if genre.is_file() else args.genre
genre = genre.strip()

lyrics = Path(args.lyrics)
lyrics = lyrics.read_text(encoding="utf-8") if lyrics.is_file() else args.lyrics
# FIX: re-append one trailing newline after stripping. The segment regex
# below requires "\n" before the next "[tag]" or before end-of-text, so a
# fully stripped input silently dropped the final lyrics segment.
lyrics = lyrics.strip() + "\n"

# Split "[verse]...\n[chorus]..." style lyrics into (tag, text) segments.
lyrics = re.findall(r"\[(\w+)\](.*?)\n(?=\[|\Z)", lyrics, re.DOTALL)
lyrics = [f"[{l[0]}]\n{l[1].strip()}\n\n" for l in lyrics]
lyrics_joined = "\n".join(lyrics)

gen_settings = ExLlamaV2Sampler.Settings()
# Constrain sampling to <EOA> (id 32002) plus the xcodec audio-token range.
gen_settings.allow_tokens(tokenizer, [32002] + list(range(45334, 46358)))
gen_settings.temperature = args.temperature
gen_settings.token_repetition_penalty = args.repetition_penalty
gen_settings.top_p = args.top_p

# FIX: compare against None so an explicit "--seed 0" is honored instead of
# being replaced by a random seed.
seed = args.seed if args.seed is not None else random.randint(0, 2**64 - 1)
stop_conditions = ["<EOA>"]

output_joined = ""
output = []

with Timer() as timer:
    for segment in lyrics:
        current = []

        # Renamed from "input" to avoid shadowing the builtin.
        prompt = (
            "Generate music from the given lyrics segment by segment.\n"
            f"[Genre] {genre}\n"
            f"{lyrics_joined}{output_joined}[start_of_segment]{segment}<SOA><xcodec>"
        )

        input_ids = tokenizer.encode(prompt, encode_special_tokens=True)
        input_len = input_ids.shape[-1]
        max_new_tokens = config.max_seq_len - input_len

        print(
            f"Using {input_len} tokens of {config.max_seq_len} tokens "
            f"with {max_new_tokens} tokens left."
        )

        job = ExLlamaV2DynamicJob(
            input_ids=input_ids,
            max_new_tokens=max_new_tokens,
            gen_settings=gen_settings,
            seed=seed,
            stop_conditions=stop_conditions,
            decode_special_tokens=True,
        )

        generator.enqueue(job)

        with Timer() as inner:
            while generator.num_remaining_jobs():
                for result in generator.iterate():
                    if result.get("stage") == "streaming":
                        text = result.get("text")

                        if text:
                            current.append(text)
                            output.append(text)

                            if args.debug:
                                print(text, end="", flush=True)

                    # Fold the finished segment back into the prompt so that
                    # later segments are conditioned on everything generated
                    # so far.
                    if result.get("eos") and current:
                        current_joined = "".join(current)
                        output_joined += (
                            f"[start_of_segment]{segment}<SOA><xcodec>"
                            f"{current_joined}<EOA>[end_of_segment]"
                        )

        if args.debug:
            print()

        print(f"Generated {len(current)} tokens in {inner.interval:.2f} seconds.")

print(f"Finished in {timer.interval:.2f} seconds with seed {seed}.")

with Timer() as timer:
    codec_config = Path("xcodec_mini_infer/final_ckpt/config.yaml")
    codec_config = yaml.safe_load(codec_config.read_bytes())
    codec = SoundStream(**codec_config["generator"]["config"])
    # NOTE(review): torch.load on a pickle checkpoint can execute arbitrary
    # code; only load checkpoints you trust (consider weights_only=True on
    # torch>=2.0 if the checkpoint supports it).
    state_dict = torch.load("xcodec_mini_infer/final_ckpt/ckpt_00360000.pth")
    codec.load_state_dict(state_dict["codec_model"])
    codec = codec.eval().cuda()

print(f"Loaded codec in {timer.interval:.2f} seconds.")


def _decode_track(ids, path):
    """Decode a list of xcodec token ids to a waveform and save it to *path*."""
    codes = torch.tensor([[ids]]).cuda().permute(1, 0, 2)
    audio = codec.decode(codes).squeeze(0).cpu()
    torchaudio.save(path, audio, args.sample_rate)


with Timer() as timer, torch.inference_mode():
    # FIX: use the compiled pattern's capture group instead of the magic
    # slice o[10:-1], which silently breaks if the token format ever changes.
    pattern = re.compile(r"<xcodec/0/(\d+)>")
    matches = (pattern.match(o) for o in output)
    output_ids = [int(m.group(1)) for m in matches if m]

    # Tokens are interleaved: even positions are vocal, odd are instrumental.
    _decode_track(output_ids[::2], "vocal.wav")
    _decode_track(output_ids[1::2], "inst.wav")

print(f"Decoded audio in {timer.interval:.2f} seconds.")
169
+ ```