ginipick committed on
Commit
bd2cd71
·
verified ·
1 Parent(s): eb53fd4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +188 -151
app.py CHANGED
@@ -3,176 +3,181 @@ import subprocess
3
  import os
4
  import shutil
5
  import tempfile
 
 
 
 
 
6
 
7
- # Install required package
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  def install_flash_attn():
9
  try:
10
- print("Installing flash-attn...")
11
  subprocess.run(
12
  ["pip", "install", "flash-attn", "--no-build-isolation"],
13
- check=True
 
14
  )
15
- print("flash-attn installed successfully!")
16
  except subprocess.CalledProcessError as e:
17
- print(f"Failed to install flash-attn: {e}")
18
- exit(1)
19
-
20
- # Install flash-attn
21
- install_flash_attn()
22
 
23
- from huggingface_hub import snapshot_download
24
-
25
- # Create xcodec_mini_infer folder
26
- folder_path = './inference/xcodec_mini_infer'
 
 
 
 
 
 
 
27
 
28
- # Create the folder if it doesn't exist
29
- if not os.path.exists(folder_path):
30
- os.mkdir(folder_path)
31
- print(f"Folder created at: {folder_path}")
32
- else:
33
- print(f"Folder already exists at: {folder_path}")
34
 
35
- snapshot_download(
36
- repo_id = "m-a-p/xcodec_mini_infer",
37
- local_dir = "./inference/xcodec_mini_infer"
38
- )
 
 
 
39
 
40
- # Change to the "inference" directory
41
- inference_dir = "./inference"
42
- try:
43
- os.chdir(inference_dir)
44
- print(f"Changed working directory to: {os.getcwd()}")
45
- except FileNotFoundError:
46
- print(f"Directory not found: {inference_dir}")
47
- exit(1)
48
 
49
  def empty_output_folder(output_dir):
50
- # List all files in the output directory
51
- files = os.listdir(output_dir)
52
-
53
- # Iterate over the files and remove them
54
- for file in files:
55
- file_path = os.path.join(output_dir, file)
56
- try:
57
- if os.path.isdir(file_path):
58
- # If it's a directory, remove it recursively
59
- shutil.rmtree(file_path)
60
- else:
61
- # If it's a file, delete it
62
- os.remove(file_path)
63
- except Exception as e:
64
- print(f"Error deleting file {file_path}: {e}")
65
-
66
- # Function to create a temporary file with string content
67
  def create_temp_file(content, prefix, suffix=".txt"):
68
  temp_file = tempfile.NamedTemporaryFile(delete=False, mode="w", prefix=prefix, suffix=suffix)
69
- # Ensure content ends with newline and normalize line endings
70
- content = content.strip() + "\n\n" # Add extra newline at end
71
  content = content.replace("\r\n", "\n").replace("\r", "\n")
72
  temp_file.write(content)
73
  temp_file.close()
74
-
75
- # Debug: Print file contents
76
- print(f"\nContent written to {prefix}{suffix}:")
77
- print(content)
78
- print("---")
79
-
80
  return temp_file.name
81
 
82
  def get_last_mp3_file(output_dir):
83
- # List all files in the output directory
84
- files = os.listdir(output_dir)
85
-
86
- # Filter only .mp3 files
87
- mp3_files = [file for file in files if file.endswith('.mp3')]
88
-
89
  if not mp3_files:
90
- print("No .mp3 files found in the output folder.")
91
  return None
92
 
93
- # Get the full path for the mp3 files
94
- mp3_files_with_path = [os.path.join(output_dir, file) for file in mp3_files]
95
-
96
- # Sort the files based on the modification time (most recent first)
97
- mp3_files_with_path.sort(key=lambda x: os.path.getmtime(x), reverse=True)
98
-
99
- # Return the most recent .mp3 file
100
  return mp3_files_with_path[0]
101
 
 
102
  def infer(genre_txt_content, lyrics_txt_content, num_segments, max_new_tokens):
103
- # Create temporary files
104
- genre_txt_path = create_temp_file(genre_txt_content, prefix="genre_")
105
- lyrics_txt_path = create_temp_file(lyrics_txt_content, prefix="lyrics_")
106
-
107
- print(f"Genre TXT path: {genre_txt_path}")
108
- print(f"Lyrics TXT path: {lyrics_txt_path}")
109
-
110
- # Ensure the output folder exists
111
- output_dir = "./output"
112
- os.makedirs(output_dir, exist_ok=True)
113
- print(f"Output folder ensured at: {output_dir}")
114
-
115
- empty_output_folder(output_dir)
116
-
117
- # Command and arguments with optimized settings
118
- command = [
119
- "python", "infer.py",
120
- "--stage1_model", "m-a-p/YuE-s1-7B-anneal-en-cot",
121
- "--stage2_model", "m-a-p/YuE-s2-1B-general",
122
- "--genre_txt", f"{genre_txt_path}",
123
- "--lyrics_txt", f"{lyrics_txt_path}",
124
- "--run_n_segments", f"{num_segments}",
125
- "--stage2_batch_size", "4",
126
- "--output_dir", f"{output_dir}",
127
- "--cuda_idx", "0",
128
- "--max_new_tokens", f"{max_new_tokens}",
129
- "--disable_offload_model"
130
- ]
131
-
132
- # Set up environment variables for CUDA with optimized settings
133
- env = os.environ.copy()
134
- env.update({
135
- "CUDA_VISIBLE_DEVICES": "0",
136
- "CUDA_HOME": "/usr/local/cuda",
137
- "PATH": f"/usr/local/cuda/bin:{env.get('PATH', '')}",
138
- "LD_LIBRARY_PATH": f"/usr/local/cuda/lib64:{env.get('LD_LIBRARY_PATH', '')}"
139
- })
140
-
141
- # Execute the command
142
  try:
143
- subprocess.run(command, check=True, env=env)
144
- print("Command executed successfully!")
 
145
 
146
- # Check and print the contents of the output folder
147
- output_files = os.listdir(output_dir)
148
- if output_files:
149
- print("Output folder contents:")
150
- for file in output_files:
151
- print(f"- {file}")
152
-
153
- last_mp3 = get_last_mp3_file(output_dir)
154
-
155
- if last_mp3:
156
- print("Last .mp3 file:", last_mp3)
157
- return last_mp3
158
- else:
159
- return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
160
  else:
161
- print("Output folder is empty.")
162
  return None
163
- except subprocess.CalledProcessError as e:
164
- print(f"Error occurred: {e}")
165
- return None
 
166
  finally:
167
- # Clean up temporary files
168
- os.remove(genre_txt_path)
169
- os.remove(lyrics_txt_path)
170
- print("Temporary files deleted.")
 
 
 
171
 
172
- # Gradio
173
  with gr.Blocks() as demo:
174
  with gr.Column():
175
- gr.Markdown("# YuE: Open Music Foundation Models for Full-Song Generation")
176
  gr.HTML("""
177
  <div style="display:flex;column-gap:4px;">
178
  <a href="https://github.com/multimodal-art-projection/YuE">
@@ -181,24 +186,43 @@ with gr.Blocks() as demo:
181
  <a href="https://map-yue.github.io">
182
  <img src='https://img.shields.io/badge/Project-Page-green'>
183
  </a>
184
- <a href="https://huggingface.co/spaces/fffiloni/YuE?duplicate=true">
185
- <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-sm.svg" alt="Duplicate this Space">
186
- </a>
187
  </div>
188
  """)
 
189
  with gr.Row():
190
  with gr.Column():
191
- genre_txt = gr.Textbox(label="Genre")
192
- lyrics_txt = gr.Textbox(label="Lyrics")
 
 
 
 
 
 
 
193
 
194
  with gr.Column():
195
- num_segments = gr.Number(label="Number of Song Segments", value=2, interactive=True)
196
- max_new_tokens = gr.Slider(label="Max New Tokens", minimum=500, maximum=24000, step=500, value=3000, interactive=True)
197
- submit_btn = gr.Button("Submit")
198
- music_out = gr.Audio(label="Audio Result")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
199
 
200
  gr.Examples(
201
- examples = [
202
  [
203
  "female blues airy vocal bright vocal piano sad romantic guitar jazz",
204
  """[verse]
@@ -233,13 +257,26 @@ Through the highs and lows, I'mma keep it real
233
  Living out my dreams with this mic and a deal
234
  """
235
  ]
236
- ],
237
- inputs = [genre_txt, lyrics_txt]
238
  )
 
 
 
239
 
 
240
  submit_btn.click(
241
- fn = infer,
242
- inputs = [genre_txt, lyrics_txt, num_segments, max_new_tokens],
243
- outputs = [music_out]
244
  )
245
- demo.queue().launch(show_api=True, show_error=True)
 
 
 
 
 
 
 
 
 
 
3
  import os
4
  import shutil
5
  import tempfile
6
+ import torch
7
+ import logging
8
+ import numpy as np
9
+ from concurrent.futures import ThreadPoolExecutor
10
+ from functools import lru_cache
11
 
12
+ # 로깅 설정
13
+ logging.basicConfig(
14
+ level=logging.INFO,
15
+ format='%(asctime)s - %(levelname)s - %(message)s',
16
+ handlers=[
17
+ logging.FileHandler('yue_generation.log'),
18
+ logging.StreamHandler()
19
+ ]
20
+ )
21
+
22
+ # GPU 설정 최적화
23
+ def optimize_gpu_settings():
24
+ if torch.cuda.is_available():
25
+ # L40S에 최적화된 설정
26
+ torch.backends.cuda.matmul.allow_tf32 = True
27
+ torch.backends.cudnn.benchmark = True
28
+ torch.backends.cudnn.deterministic = False
29
+ torch.backends.cudnn.enabled = True
30
+
31
+ # GPU 메모리 설정
32
+ torch.cuda.empty_cache()
33
+ torch.cuda.set_device(0)
34
+
35
+ logging.info(f"Using GPU: {torch.cuda.get_device_name(0)}")
36
+ logging.info(f"Available GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")
37
+ else:
38
+ logging.warning("GPU not available!")
39
+
40
+ # flash-attn 설치 함수 개선
41
  def install_flash_attn():
42
  try:
43
+ logging.info("Installing flash-attn...")
44
  subprocess.run(
45
  ["pip", "install", "flash-attn", "--no-build-isolation"],
46
+ check=True,
47
+ capture_output=True
48
  )
49
+ logging.info("flash-attn installed successfully!")
50
  except subprocess.CalledProcessError as e:
51
+ logging.error(f"Failed to install flash-attn: {e}")
52
+ raise
 
 
 
53
 
54
+ # 초기화 함수
55
+ def initialize_system():
56
+ optimize_gpu_settings()
57
+ install_flash_attn()
58
+
59
+ from huggingface_hub import snapshot_download
60
+
61
+ # xcodec_mini_infer 폴더 생성
62
+ folder_path = './inference/xcodec_mini_infer'
63
+ os.makedirs(folder_path, exist_ok=True)
64
+ logging.info(f"Created folder at: {folder_path}")
65
 
66
+ # ๋ชจ๋ธ ๋‹ค์šด๋กœ๋“œ
67
+ snapshot_download(
68
+ repo_id="m-a-p/xcodec_mini_infer",
69
+ local_dir="./inference/xcodec_mini_infer",
70
+ resume_download=True
71
+ )
72
 
73
+ # inference 디렉토리로 이동
74
+ try:
75
+ os.chdir("./inference")
76
+ logging.info(f"Working directory changed to: {os.getcwd()}")
77
+ except FileNotFoundError as e:
78
+ logging.error(f"Directory error: {e}")
79
+ raise
80
 
81
+ # ์บ์‹œ๋ฅผ ํ™œ์šฉํ•œ ํŒŒ์ผ ๊ด€๋ฆฌ
82
+ @lru_cache(maxsize=100)
83
+ def get_cached_file_path(content_hash, prefix):
84
+ return create_temp_file(content_hash, prefix)
 
 
 
 
85
 
86
  def empty_output_folder(output_dir):
87
+ try:
88
+ shutil.rmtree(output_dir)
89
+ os.makedirs(output_dir)
90
+ logging.info(f"Output folder cleaned: {output_dir}")
91
+ except Exception as e:
92
+ logging.error(f"Error cleaning output folder: {e}")
93
+ raise
94
+
 
 
 
 
 
 
 
 
 
95
  def create_temp_file(content, prefix, suffix=".txt"):
96
  temp_file = tempfile.NamedTemporaryFile(delete=False, mode="w", prefix=prefix, suffix=suffix)
97
+ content = content.strip() + "\n\n"
 
98
  content = content.replace("\r\n", "\n").replace("\r", "\n")
99
  temp_file.write(content)
100
  temp_file.close()
101
+ logging.debug(f"Temporary file created: {temp_file.name}")
 
 
 
 
 
102
  return temp_file.name
103
 
104
  def get_last_mp3_file(output_dir):
105
+ mp3_files = [f for f in os.listdir(output_dir) if f.endswith('.mp3')]
 
 
 
 
 
106
  if not mp3_files:
107
+ logging.warning("No MP3 files found")
108
  return None
109
 
110
+ mp3_files_with_path = [os.path.join(output_dir, f) for f in mp3_files]
111
+ mp3_files_with_path.sort(key=os.path.getmtime, reverse=True)
 
 
 
 
 
112
  return mp3_files_with_path[0]
113
 
114
+ # L40S에 최적화된 추론 함수
115
  def infer(genre_txt_content, lyrics_txt_content, num_segments, max_new_tokens):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
  try:
117
+ # 임시 파일 생성
118
+ genre_txt_path = create_temp_file(genre_txt_content, prefix="genre_")
119
+ lyrics_txt_path = create_temp_file(lyrics_txt_content, prefix="lyrics_")
120
 
121
+ output_dir = "./output"
122
+ os.makedirs(output_dir, exist_ok=True)
123
+ empty_output_folder(output_dir)
124
+
125
+ # L40S에 최적화된 명령어
126
+ command = [
127
+ "python", "infer.py",
128
+ "--stage1_model", "m-a-p/YuE-s1-7B-anneal-en-cot",
129
+ "--stage2_model", "m-a-p/YuE-s2-1B-general",
130
+ "--genre_txt", genre_txt_path,
131
+ "--lyrics_txt", lyrics_txt_path,
132
+ "--run_n_segments", str(num_segments),
133
+ "--stage2_batch_size", "8", # L40S์— ๋งž๊ฒŒ ์ฆ๊ฐ€
134
+ "--output_dir", output_dir,
135
+ "--cuda_idx", "0",
136
+ "--max_new_tokens", str(max_new_tokens),
137
+ "--disable_offload_model",
138
+ "--use_flash_attention_2", # Flash Attention 2 ํ™œ์„ฑํ™”
139
+ "--bf16" # BF16 ์ •๋ฐ€๋„ ์‚ฌ์šฉ
140
+ ]
141
+
142
+ # CUDA 환경 변수 설정
143
+ env = os.environ.copy()
144
+ env.update({
145
+ "CUDA_VISIBLE_DEVICES": "0",
146
+ "CUDA_HOME": "/usr/local/cuda",
147
+ "PATH": f"/usr/local/cuda/bin:{env.get('PATH', '')}",
148
+ "LD_LIBRARY_PATH": f"/usr/local/cuda/lib64:{env.get('LD_LIBRARY_PATH', '')}",
149
+ "PYTORCH_CUDA_ALLOC_CONF": "max_split_size_mb:512"
150
+ })
151
+
152
+ # 명령 실행
153
+ process = subprocess.run(command, env=env, check=True, capture_output=True)
154
+ logging.info("Inference completed successfully")
155
+
156
+ # 결과 처리
157
+ last_mp3 = get_last_mp3_file(output_dir)
158
+ if last_mp3:
159
+ logging.info(f"Generated audio file: {last_mp3}")
160
+ return last_mp3
161
  else:
162
+ logging.warning("No output audio file generated")
163
  return None
164
+
165
+ except Exception as e:
166
+ logging.error(f"Inference error: {e}")
167
+ raise
168
  finally:
169
+ # 임시 파일 정리
170
+ for file in [genre_txt_path, lyrics_txt_path]:
171
+ try:
172
+ os.remove(file)
173
+ logging.debug(f"Removed temporary file: {file}")
174
+ except Exception as e:
175
+ logging.warning(f"Failed to remove temporary file {file}: {e}")
176
 
177
+ # Gradio 인터페이스
178
  with gr.Blocks() as demo:
179
  with gr.Column():
180
+ gr.Markdown("# YuE: Open Music Foundation Models for Full-Song Generation (L40S Optimized)")
181
  gr.HTML("""
182
  <div style="display:flex;column-gap:4px;">
183
  <a href="https://github.com/multimodal-art-projection/YuE">
 
186
  <a href="https://map-yue.github.io">
187
  <img src='https://img.shields.io/badge/Project-Page-green'>
188
  </a>
 
 
 
189
  </div>
190
  """)
191
+
192
  with gr.Row():
193
  with gr.Column():
194
+ genre_txt = gr.Textbox(
195
+ label="Genre",
196
+ placeholder="Enter music genre and style descriptions..."
197
+ )
198
+ lyrics_txt = gr.Textbox(
199
+ label="Lyrics",
200
+ placeholder="Enter song lyrics...",
201
+ lines=10
202
+ )
203
 
204
  with gr.Column():
205
+ num_segments = gr.Number(
206
+ label="Number of Song Segments",
207
+ value=2,
208
+ minimum=1,
209
+ maximum=4,
210
+ step=1,
211
+ interactive=True
212
+ )
213
+ max_new_tokens = gr.Slider(
214
+ label="Max New Tokens",
215
+ minimum=500,
216
+ maximum=32000, # L40S의 큰 메모리를 활용
217
+ step=500,
218
+ value=4000,
219
+ interactive=True
220
+ )
221
+ submit_btn = gr.Button("Generate Music", variant="primary")
222
+ music_out = gr.Audio(label="Generated Audio")
223
 
224
  gr.Examples(
225
+ examples=[
226
  [
227
  "female blues airy vocal bright vocal piano sad romantic guitar jazz",
228
  """[verse]
 
257
  Living out my dreams with this mic and a deal
258
  """
259
  ]
260
+ ],
261
+ inputs=[genre_txt, lyrics_txt]
262
  )
263
+
264
+ # 시스템 초기화
265
+ initialize_system()
266
 
267
+ # 이벤트 핸들러
268
  submit_btn.click(
269
+ fn=infer,
270
+ inputs=[genre_txt, lyrics_txt, num_segments, max_new_tokens],
271
+ outputs=[music_out]
272
  )
273
+
274
+ # 서버 설정으로 실행
275
+ demo.queue(concurrency_count=2).launch(
276
+ server_name="0.0.0.0",
277
+ server_port=7860,
278
+ share=True,
279
+ enable_queue=True,
280
+ show_api=True,
281
+ show_error=True
282
+ )