Update handler.py
handler.py +67 -196

handler.py CHANGED
@@ -1,59 +1,16 @@
-from typing import Dict, Any
+from typing import Dict, Any
 import torch
 import numpy as np
 import json
-import time
-import gc
-from threading import Lock
-import os
 
 class EndpointHandler:
     def __init__(self, path=""):
-        """Initialize model on startup
+        """Initialize model on startup"""
         try:
             from audiocraft.models import MusicGen
 
-            # Configure PyTorch for better GPU performance
-            torch.backends.cudnn.benchmark = True # Enable cuDNN auto-tuner
-            torch.backends.cuda.matmul.allow_tf32 = True # Allow TF32 on Ampere+ GPUs
-            torch.backends.cudnn.allow_tf32 = True # Allow TF32 for cuDNN
-
-            # Set environment variables for better GPU performance
-            os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" # Optimize CUDA connections
-
             # Load model - using melody model which supports text and melody inputs
-            # Can also use 'facebook/musicgen-large' for even higher quality at more GPU usage
             self.model = MusicGen.get_pretrained('melody')
-
-            # Get GPU memory stats before model optimization
-            gpu_mem_before = torch.cuda.memory_allocated() / (1024 ** 3) # in GB
-            print(f"GPU memory used before optimization: {gpu_mem_before:.2f} GB")
-
-            # Optimize model for inference
-            self.model.eval() # Set model to evaluation mode
-
-            # Use mixed precision for faster inference
-            self.fp16_mode = True
-            if self.fp16_mode:
-                self.model = self.model.half() # Convert to FP16 for faster inference
-                print("Model converted to FP16 for faster inference")
-
-            # Optional: Use torch.compile() for PyTorch 2.0+ (significant speedup)
-            try:
-                if hasattr(torch, 'compile'):
-                    # For PyTorch 2.0+, enable torch.compile for faster inference
-                    self.model.lm = torch.compile(self.model.lm, mode="reduce-overhead")
-                    print("Using torch.compile() for optimized inference")
-            except Exception as compile_error:
-                print(f"Warning: torch.compile optimization failed: {compile_error}")
-
-            # Cache the model on GPU
-            self.model = self.model.cuda()
-
-            # Apply CUDA graph optimization for repeated workloads of the same size
-            self.use_cuda_graphs = False # Enable if generating fixed-size outputs repeatedly
-
-            # Track model sample rate
             self.sample_rate = self.model.sample_rate
 
             # Set default generation parameters
@@ -62,57 +19,13 @@ class EndpointHandler:
                 top_k=250,
                 duration=30 # Default segment length
             )
-
-            # Create a batch processing queue for multiple requests
-            self.batch_size = 1 # Can be increased for batch processing
-            self.request_lock = Lock() # Lock for thread safety
-
-            # Get GPU memory after optimization
-            torch.cuda.synchronize() # Ensure GPU operations are complete
-            gpu_mem_after = torch.cuda.memory_allocated() / (1024 ** 3) # in GB
-            print(f"GPU memory used after optimization: {gpu_mem_after:.2f} GB")
-            print(f"Additional memory used: {gpu_mem_after - gpu_mem_before:.2f} GB")
-
-            # Warm up the model with a dummy forward pass
-            self._warmup_model()
-
         except Exception as e:
             # Keep critical error logging only
             print(f"CRITICAL: Failed to initialize model: {e}")
             raise
 
-    def _warmup_model(self):
-        """Perform a warm-up inference to initialize CUDA kernels"""
-        try:
-            print("Warming up model with dummy inference...")
-            start_time = time.time()
-
-            # Create a simple prompt for warm-up
-            dummy_prompt = ["warm up"]
-
-            # Set minimal duration for warm-up
-            self.model.set_generation_params(duration=1)
-
-            # Run inference with dummy input
-            with torch.cuda.amp.autocast(enabled=self.fp16_mode):
-                with torch.no_grad():
-                    _ = self.model.generate(dummy_prompt, progress=False)
-
-            # Synchronize GPU to ensure completion
-            torch.cuda.synchronize()
-
-            # Clear GPU cache after warm-up
-            torch.cuda.empty_cache()
-
-            end_time = time.time()
-            print(f"Model warm-up completed in {end_time - start_time:.2f} seconds")
-        except Exception as e:
-            print(f"Warning: Model warm-up failed: {e}")
-
     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
-        """Handle prediction requests
-        start_time = time.time()
-
+        """Handle prediction requests"""
         try:
             # Parse request data
             inputs = data.get("inputs", {})
@@ -138,18 +51,8 @@
             if not prompt:
                 return {"error": "No prompt provided"}
 
-            # Generate music
-
-            audio_data = self.generate_music(prompt, duration, parameters)
-
-            # Log processing time for monitoring performance
-            end_time = time.time()
-            processing_time = end_time - start_time
-            print(f"Request processed in {processing_time:.2f} seconds (duration: {duration:.1f}s)")
-
-            # Log GPU memory usage for monitoring
-            gpu_mem = torch.cuda.memory_allocated() / (1024 ** 3) # in GB
-            print(f"Current GPU memory usage: {gpu_mem:.2f} GB")
+            # Generate music
+            audio_data = self.generate_music(prompt, duration, parameters)
 
             return {
                 "generated_audio": audio_data.tolist(),
@@ -161,17 +64,12 @@ class EndpointHandler:
 
         except Exception as e:
             print(f"ERROR: Request processing failed: {e}")
-            end_time = time.time()
-            print(f"Failed request took {end_time - start_time:.2f} seconds")
             return {"error": str(e)}
 
     def generate_music(self, prompt: str, duration: float, parameters: Dict) -> np.ndarray:
-        """
-        Generate music with proper continuation for long sequences
-        Optimized for maximum GPU utilization
-        """
+        """Generate music with proper continuation for long sequences"""
         try:
-            # Generation parameters
+            # Generation parameters
             segment_duration = min(30, duration) # Max segment length (30s)
             overlap = 5 # Overlap between segments in seconds
 
@@ -185,10 +83,6 @@
                 "cfg_coef": parameters.get("cfg_coef", 3.0)
             }
 
-            # Additional parameters for GPU optimization
-            guidance_scale = parameters.get("guidance_scale", generation_params["cfg_coef"])
-            generation_params["cfg_coef"] = guidance_scale # Support both parameter names
-
             # Set generation parameters
             self.model.set_generation_params(**generation_params)
 
@@ -196,91 +90,68 @@
             if isinstance(prompt, str):
                 prompt = [prompt]
 
-            #
[contents of removed lines 200-247 are not shown in the source view]
-                    progress=False
-                )
-                torch.cuda.synchronize() # Ensure generation is complete
-                cont_end = time.time()
-
-                # Join segments (removing overlap from first segment)
-                segment = torch.cat([segment[:, :, :-overlap*self.sample_rate], next_segment], 2)
-
-                # Update remaining duration
-                if remaining_duration < segment_duration - overlap:
-                    seg_duration = remaining_duration
-                    remaining_duration = 0
-                else:
-                    seg_duration = segment_duration - overlap
-                    remaining_duration -= seg_duration
-
-                segment_count += 1
-                print(f"Segment {segment_count} ({seg_duration:.1f}s) generated in {cont_end - cont_start:.2f} seconds")
-
-            # Trim to exact requested duration if needed
-            max_samples = int(duration * self.sample_rate)
-            if segment.shape[2] > max_samples:
-                segment = segment[:, :, :max_samples]
-
-            # Convert to numpy array
-            audio_data = segment.detach().cpu().float()[0].numpy()
-
-            # Optional memory cleanup after large generations
-            if duration > 60:
-                # Clear GPU cache after large generations
-                torch.cuda.empty_cache()
+            # Generate first segment
+            segment = self.model.generate(prompt, progress=False) # Disabled progress tracking
+
+            # If duration is less than or equal to segment_duration, we're done
+            if duration <= segment_duration:
+                # Trim to exact requested duration if needed
+                max_samples = int(duration * self.sample_rate)
+                if segment.shape[2] > max_samples:
+                    segment = segment[:, :, :max_samples]
+                audio_data = segment.detach().cpu().float()[0].numpy()
+                return audio_data
+
+            # Track remaining duration for multi-segment generation
+            remaining_duration = duration - segment_duration + overlap
+            segment_count = 1
+
+            # Continue generating segments until we reach desired duration
+            while remaining_duration > 0:
+                # Adjust segment duration for last segment if needed
+                if remaining_duration < segment_duration - overlap:
+                    current_segment_duration = remaining_duration + overlap
+                    self.model.set_generation_params(
+                        use_sampling=generation_params["use_sampling"],
+                        top_k=generation_params["top_k"],
+                        top_p=generation_params["top_p"],
+                        temperature=generation_params["temperature"],
+                        duration=current_segment_duration,
+                        cfg_coef=generation_params["cfg_coef"]
+                    )
+
+                # Extract last few seconds of current segment for continuation
+                last_seconds = segment[:, :, -overlap*self.sample_rate:]
+
+                # Generate continuation
+                next_segment = self.model.generate_continuation(
+                    last_seconds,
+                    self.sample_rate,
+                    prompt,
+                    progress=False # Disabled progress tracking
+                )
+
+                # Join segments (removing overlap from first segment)
+                segment = torch.cat([segment[:, :, :-overlap*self.sample_rate], next_segment], 2)
+
+                # Update remaining duration
+                if remaining_duration < segment_duration - overlap:
+                    remaining_duration = 0
+                else:
+                    remaining_duration -= (segment_duration - overlap)
 
-
+                segment_count += 1
+
+            # Trim to exact requested duration if needed
+            max_samples = int(duration * self.sample_rate)
+            if segment.shape[2] > max_samples:
+                segment = segment[:, :, :max_samples]
+
+            # Convert to numpy array
+            audio_data = segment.detach().cpu().float()[0].numpy()
+
+            return audio_data
 
         except Exception as e:
             print(f"ERROR: Music generation failed: {e}")
-            # Clean up GPU memory on error
-            torch.cuda.empty_cache()
             raise
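For reference, the updated handler can be exercised locally before redeploying the endpoint. The sketch below is a hypothetical smoke test, not part of this commit: only data["inputs"] and the "generated_audio" response field are visible in the diff, so the "prompt", "duration", and "parameters" keys are assumptions, and the script presumes a CUDA-capable machine with audiocraft installed.

# Hypothetical local smoke test for the updated handler (not part of the commit).
# The "prompt"/"duration"/"parameters" keys inside "inputs" are assumed; only
# data["inputs"] and the "generated_audio" response field appear in the diff.
import numpy as np

from handler import EndpointHandler

handler = EndpointHandler()

request = {
    "inputs": {
        "prompt": "warm lo-fi piano with soft drums",        # assumed key
        "duration": 45,                                       # assumed key
        "parameters": {"temperature": 1.0, "cfg_coef": 3.0},  # assumed key
    }
}

response = handler(request)
if "error" in response:
    print("Request failed:", response["error"])
else:
    audio = np.asarray(response["generated_audio"])  # (channels, samples) per the handler
    print(f"Got {audio.shape[-1] / handler.sample_rate:.1f} s of audio at {handler.sample_rate} Hz")

A 45-second request exercises both paths in generate_music: the first generate call produces a 30-second segment, then the while loop runs once, generating a continuation from the last 5 seconds and trimming the stitched result to the requested length.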