ginipick committed on
Commit 3c1a098 · verified · 1 Parent(s): bd2cd71

Update app.py

Files changed (1): app.py (+93 -31)
app.py CHANGED
@@ -6,6 +6,7 @@ import tempfile
 import torch
 import logging
 import numpy as np
+import re
 from concurrent.futures import ThreadPoolExecutor
 from functools import lru_cache
 
@@ -19,16 +20,48 @@ logging.basicConfig(
     ]
 )
 
+# Language detection and model selection
+def detect_and_select_model(text):
+    if re.search(r'[\u3131-\u318E\uAC00-\uD7A3]', text):  # Korean
+        return "m-a-p/YuE-s1-7B-anneal-jp-kr-cot"
+    elif re.search(r'[\u4e00-\u9fff]', text):  # Chinese
+        return "m-a-p/YuE-s1-7B-anneal-zh-cot"
+    elif re.search(r'[\u3040-\u309F\u30A0-\u30FF]', text):  # Japanese
+        return "m-a-p/YuE-s1-7B-anneal-jp-kr-cot"
+    else:  # English and everything else
+        return "m-a-p/YuE-s1-7B-anneal-en-cot"
+
+def optimize_model_selection(lyrics, genre):
+    model_path = detect_and_select_model(lyrics)
+
+    model_config = {
+        "m-a-p/YuE-s1-7B-anneal-en-cot": {
+            "max_tokens": 24000,
+            "temperature": 0.8,
+            "batch_size": 8
+        },
+        "m-a-p/YuE-s1-7B-anneal-jp-kr-cot": {
+            "max_tokens": 24000,
+            "temperature": 0.7,
+            "batch_size": 8
+        },
+        "m-a-p/YuE-s1-7B-anneal-zh-cot": {
+            "max_tokens": 24000,
+            "temperature": 0.7,
+            "batch_size": 8
+        }
+    }
+
+    return model_path, model_config[model_path]
+
 # Optimize GPU settings
 def optimize_gpu_settings():
     if torch.cuda.is_available():
-        # Settings tuned for the L40S
         torch.backends.cuda.matmul.allow_tf32 = True
         torch.backends.cudnn.benchmark = True
         torch.backends.cudnn.deterministic = False
         torch.backends.cudnn.enabled = True
 
-        # GPU memory settings
         torch.cuda.empty_cache()
         torch.cuda.set_device(0)
 
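The routing above is easy to sanity-check in isolation. A minimal standalone sketch (duplicating the function so it runs on its own), which also records one subtlety of the branch order: because the CJK-ideograph range is tested before the kana range, Japanese lyrics that contain kanji resolve to the Chinese checkpoint, and only kana-only text reaches the jp-kr branch.

```python
import re

def detect_and_select_model(text):
    if re.search(r'[\u3131-\u318E\uAC00-\uD7A3]', text):    # Hangul jamo + syllables
        return "m-a-p/YuE-s1-7B-anneal-jp-kr-cot"
    elif re.search(r'[\u4e00-\u9fff]', text):               # CJK ideographs
        return "m-a-p/YuE-s1-7B-anneal-zh-cot"
    elif re.search(r'[\u3040-\u309F\u30A0-\u30FF]', text):  # hiragana + katakana
        return "m-a-p/YuE-s1-7B-anneal-jp-kr-cot"
    else:
        return "m-a-p/YuE-s1-7B-anneal-en-cot"

assert detect_and_select_model("빛나는 별들처럼").endswith("jp-kr-cot")
assert detect_and_select_model("春の風に乗って").endswith("zh-cot")  # kanji hits the CJK range first
assert detect_and_select_model("かぜにのって").endswith("jp-kr-cot")  # kana-only resolves to jp-kr
assert detect_and_select_model("Stay with me").endswith("en-cot")
```

Note that the Japanese example added further down contains kanji, so under this branch order it would route to the zh checkpoint rather than jp-kr.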
@@ -37,7 +70,6 @@ def optimize_gpu_settings():
     else:
         logging.warning("GPU not available!")
 
-# Improved flash-attn installation helper
 def install_flash_attn():
     try:
         logging.info("Installing flash-attn...")
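The hunks elide the middle of install_flash_attn. A plausible shape for the elided body, assuming a plain pip install driven through subprocess (flash-attn's own docs recommend --no-build-isolation; the exact flags in app.py may differ):

```python
import subprocess
import logging

def install_flash_attn():
    """Best-effort flash-attn install; a sketch, not the exact app.py body."""
    try:
        logging.info("Installing flash-attn...")
        subprocess.run(
            ["pip", "install", "flash-attn", "--no-build-isolation"],
            check=True,  # raise CalledProcessError on a non-zero exit
        )
        logging.info("flash-attn installed successfully")
    except subprocess.CalledProcessError as e:
        logging.error(f"Failed to install flash-attn: {e}")
        raise
```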
@@ -51,26 +83,22 @@ def install_flash_attn():
         logging.error(f"Failed to install flash-attn: {e}")
         raise
 
-# System initialization
 def initialize_system():
     optimize_gpu_settings()
     install_flash_attn()
 
     from huggingface_hub import snapshot_download
 
-    # Create the xcodec_mini_infer folder
     folder_path = './inference/xcodec_mini_infer'
     os.makedirs(folder_path, exist_ok=True)
     logging.info(f"Created folder at: {folder_path}")
 
-    # Download the model
     snapshot_download(
         repo_id="m-a-p/xcodec_mini_infer",
         local_dir="./inference/xcodec_mini_infer",
         resume_download=True
     )
 
-    # Move into the inference directory
     try:
         os.chdir("./inference")
         logging.info(f"Working directory changed to: {os.getcwd()}")
@@ -78,7 +106,6 @@ def initialize_system():
         logging.error(f"Directory error: {e}")
         raise
 
-# File management with caching
 @lru_cache(maxsize=100)
 def get_cached_file_path(content_hash, prefix):
     return create_temp_file(content_hash, prefix)
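Since lru_cache keys on its arguments, this helper only pays off when callers pass a stable, hashable digest of the content. create_temp_file itself is not shown in this diff; a self-contained sketch with a plausible implementation (delete=False so the subprocess can reopen the file by path) and a hypothetical content_key helper illustrating the intended call pattern:

```python
import hashlib
import tempfile
from functools import lru_cache

def create_temp_file(content: str, prefix: str, suffix: str = ".txt") -> str:
    # Sketch of the helper referenced above; delete=False keeps the file
    # on disk so infer.py can open it by path after the handle closes.
    with tempfile.NamedTemporaryFile(
        mode="w", prefix=prefix, suffix=suffix, delete=False, encoding="utf-8"
    ) as f:
        f.write(content)
        return f.name

@lru_cache(maxsize=100)
def get_cached_file_path(content_hash: str, prefix: str) -> str:
    return create_temp_file(content_hash, prefix)

def content_key(text: str) -> str:
    # Hashable digest so identical submissions share one cache entry.
    return hashlib.sha256(text.encode("utf-8")).hexdigest()

lyrics = "Stay with me forever, let our love just flow"
# The second call is served from the cache and re-creates nothing.
assert get_cached_file_path(content_key(lyrics), "lyrics_") == \
       get_cached_file_path(content_key(lyrics), "lyrics_")
```

As written, get_cached_file_path hands the digest (not the original text) to create_temp_file, so the cached file contains the hash string; presumably the surrounding app.py accounts for this.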
@@ -111,9 +138,12 @@ def get_last_mp3_file(output_dir):
     mp3_files_with_path.sort(key=os.path.getmtime, reverse=True)
     return mp3_files_with_path[0]
 
-# Inference function tuned for the L40S
 def infer(genre_txt_content, lyrics_txt_content, num_segments, max_new_tokens):
     try:
+        # Select the model and its configuration
+        model_path, config = optimize_model_selection(lyrics_txt_content, genre_txt_content)
+        logging.info(f"Selected model: {model_path}")
+
         # Create temporary files
         genre_txt_path = create_temp_file(genre_txt_content, prefix="genre_")
         lyrics_txt_path = create_temp_file(lyrics_txt_content, prefix="lyrics_")
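Worth noting: optimize_model_selection receives genre but, as defined earlier in the patch, only inspects the lyrics. If genre tags were ever meant to influence routing (say, a K-pop tag with romanized lyrics that contain no Hangul), a hypothetical fallback could consult them too; a sketch, with GENRE_HINTS and detect_with_genre_hint being inventions of this note, not part of app.py:

```python
# Hypothetical: let an explicit genre tag steer routing when the lyrics
# script is ambiguous (e.g., romanized K-pop lyrics contain no Hangul).
GENRE_HINTS = {
    "k-pop": "m-a-p/YuE-s1-7B-anneal-jp-kr-cot",
    "j-pop": "m-a-p/YuE-s1-7B-anneal-jp-kr-cot",
}

def detect_with_genre_hint(lyrics: str, genre: str) -> str:
    for tag, model in GENRE_HINTS.items():
        if tag in genre.lower():
            return model
    return detect_and_select_model(lyrics)
```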
@@ -122,21 +152,22 @@ def infer(genre_txt_content, lyrics_txt_content, num_segments, max_new_tokens):
         os.makedirs(output_dir, exist_ok=True)
         empty_output_folder(output_dir)
 
-        # Command tuned for the L40S
+        # Build the inference command
         command = [
             "python", "infer.py",
-            "--stage1_model", "m-a-p/YuE-s1-7B-anneal-en-cot",
+            "--stage1_model", model_path,
             "--stage2_model", "m-a-p/YuE-s2-1B-general",
             "--genre_txt", genre_txt_path,
             "--lyrics_txt", lyrics_txt_path,
             "--run_n_segments", str(num_segments),
-            "--stage2_batch_size", "8",  # increased for the L40S
+            "--stage2_batch_size", str(config['batch_size']),
             "--output_dir", output_dir,
             "--cuda_idx", "0",
-            "--max_new_tokens", str(max_new_tokens),
+            "--max_new_tokens", str(config['max_tokens']),
+            "--temperature", str(config['temperature']),
             "--disable_offload_model",
-            "--use_flash_attention_2",  # enable Flash Attention 2
-            "--bf16"  # use BF16 precision
+            "--use_flash_attention_2",
+            "--bf16"
         ]
 
         # Set CUDA environment variables
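For context, the argument list is presumably handed to subprocess with the CUDA variables the trailing comment refers to; the diff cuts off before the actual call. A sketch of that execution step, assuming CUDA_VISIBLE_DEVICES is the variable being set (the command list is abbreviated here):

```python
import os
import subprocess

# `command` is the argument list assembled in the hunk above.
command = ["python", "infer.py", "--cuda_idx", "0"]  # abbreviated for the sketch

env = os.environ.copy()
env["CUDA_VISIBLE_DEVICES"] = "0"  # assumption: pin to the GPU named by --cuda_idx

# Passing the list directly (no shell=True) avoids quoting issues
# in the genre/lyrics temp-file paths.
result = subprocess.run(command, env=env, capture_output=True, text=True)
if result.returncode != 0:
    raise RuntimeError(f"infer.py failed: {result.stderr}")
```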
@@ -177,7 +208,7 @@ def infer(genre_txt_content, lyrics_txt_content, num_segments, max_new_tokens):
 # Gradio interface
 with gr.Blocks() as demo:
     with gr.Column():
-        gr.Markdown("# YuE: Open Music Foundation Models for Full-Song Generation (L40S Optimized)")
+        gr.Markdown("# YuE: Open Music Foundation Models for Full-Song Generation (Multi-Language Support)")
         gr.HTML("""
         <div style="display:flex;column-gap:4px;">
             <a href="https://github.com/multimodal-art-projection/YuE">
@@ -196,7 +227,7 @@ with gr.Blocks() as demo:
             placeholder="Enter music genre and style descriptions..."
         )
         lyrics_txt = gr.Textbox(
-            label="Lyrics",
+            label="Lyrics (Supports English, Korean, Japanese, Chinese)",
             placeholder="Enter song lyrics...",
             lines=10
         )
@@ -213,7 +244,7 @@ with gr.Blocks() as demo:
         max_new_tokens = gr.Slider(
             label="Max New Tokens",
             minimum=500,
-            maximum=32000,  # take advantage of the L40S's large memory
+            maximum=32000,
             step=500,
             value=4000,
             interactive=True
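A side note on this slider: after the command changes above, the max_new_tokens value that infer() receives from the UI is no longer forwarded; the patched command always sends config['max_tokens'] (24000). If the slider is still meant to cap generation length, a small clamp inside infer() would restore that behavior; a sketch:

```python
# Sketch: honor the UI slider as an upper bound instead of dropping it.
# (As patched, config["max_tokens"] replaces the slider value entirely.)
effective_tokens = min(int(max_new_tokens), config["max_tokens"])
command += ["--max_new_tokens", str(effective_tokens)]
```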
@@ -221,8 +252,10 @@ with gr.Blocks() as demo:
         submit_btn = gr.Button("Generate Music", variant="primary")
         music_out = gr.Audio(label="Generated Audio")
 
+        # Multilingual examples
         gr.Examples(
             examples=[
+                # English example
                 [
                     "female blues airy vocal bright vocal piano sad romantic guitar jazz",
                     """[verse]
@@ -238,23 +271,52 @@ Can't imagine life alone, don't want to let you go
 Stay with me forever, let our love just flow
 """
                 ],
+                # Korean example
+                [
+                    "K-pop bright energetic synth dance electronic",
+                    """[verse]
+빛나는 별들처럼 우리의 꿈이
+저 하늘을 수놓아 반짝이네
+함께라면 어디든 갈 수 있어
+우리의 이야기가 시작되네
+
+[chorus]
+달려가자 더 높이 더 멀리
+두려움은 없어 너와 함께라면
+영원히 계속될 우리의 노래
+이 순간을 기억해 forever
+"""
+                ],
+                # Japanese example
+                [
+                    "J-pop melodic soft piano emotional",
+                    """[verse]
+春の風に乗って
+思い出が流れる
+あの日の約束を
+今でも覚えてる
+
+[chorus]
+君と見た空は
+今も変わらないよ
+どこまでも続く
+この道の先で
+"""
+                ],
+                # Chinese example
                 [
-                    "rap piano street tough piercing vocal hip-hop synthesizer clear vocal male",
+                    "Chinese pop traditional fusion modern",
                     """[verse]
-Woke up in the morning, sun is shining bright
-Chasing all my dreams, gotta get my mind right
-City lights are fading, but my vision's clear
-Got my team beside me, no room for fear
-Walking through the streets, beats inside my head
-Every step I take, closer to the bread
-People passing by, they don't understand
-Building up my future with my own two hands
+晨光照亮天际
+新的一天开始
+追逐着梦想前进
+不停歇的脚步
 
 [chorus]
-This is my life, and I'm aiming for the top
-Never gonna quit, no, I'm never gonna stop
-Through the highs and lows, I'mma keep it real
-Living out my dreams with this mic and a deal
+让希望照亮前方
+让勇气伴随身旁
+这一路有你相伴
+永远不会孤单
 """
                 ]
             ],
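The diff does not show how submit_btn is wired to infer. With the inputs defined above, the usual Gradio pattern would look like the following sketch, assuming a num_segments control is defined in the elided portion of app.py and that infer takes exactly these four inputs:

```python
# Hypothetical wiring; the actual .click() call falls outside this diff's hunks.
submit_btn.click(
    fn=infer,
    inputs=[genre_txt, lyrics_txt, num_segments, max_new_tokens],
    outputs=music_out,
)

demo.queue().launch()
```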
 