Phoenixak99 committed
Commit c5073cf · verified · 1 Parent(s): ac244c7

Update handler.py

Files changed (1):
  handler.py  +67 -196
handler.py CHANGED
@@ -1,59 +1,16 @@
-from typing import Dict, Any, List, Optional
+from typing import Dict, Any
 import torch
 import numpy as np
 import json
-import time
-import gc
-from threading import Lock
-import os
 
 class EndpointHandler:
     def __init__(self, path=""):
-        """Initialize model on startup with optimized GPU usage"""
+        """Initialize model on startup"""
         try:
             from audiocraft.models import MusicGen
 
-            # Configure PyTorch for better GPU performance
-            torch.backends.cudnn.benchmark = True  # Enable cuDNN auto-tuner
-            torch.backends.cuda.matmul.allow_tf32 = True  # Allow TF32 on Ampere+ GPUs
-            torch.backends.cudnn.allow_tf32 = True  # Allow TF32 for cuDNN
-
-            # Set environment variables for better GPU performance
-            os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"  # Optimize CUDA connections
-
             # Load model - using melody model which supports text and melody inputs
-            # Can also use 'facebook/musicgen-large' for even higher quality at more GPU usage
             self.model = MusicGen.get_pretrained('melody')
-
-            # Get GPU memory stats before model optimization
-            gpu_mem_before = torch.cuda.memory_allocated() / (1024 ** 3)  # in GB
-            print(f"GPU memory used before optimization: {gpu_mem_before:.2f} GB")
-
-            # Optimize model for inference
-            self.model.eval()  # Set model to evaluation mode
-
-            # Use mixed precision for faster inference
-            self.fp16_mode = True
-            if self.fp16_mode:
-                self.model = self.model.half()  # Convert to FP16 for faster inference
-                print("Model converted to FP16 for faster inference")
-
-            # Optional: Use torch.compile() for PyTorch 2.0+ (significant speedup)
-            try:
-                if hasattr(torch, 'compile'):
-                    # For PyTorch 2.0+, enable torch.compile for faster inference
-                    self.model.lm = torch.compile(self.model.lm, mode="reduce-overhead")
-                    print("Using torch.compile() for optimized inference")
-            except Exception as compile_error:
-                print(f"Warning: torch.compile optimization failed: {compile_error}")
-
-            # Cache the model on GPU
-            self.model = self.model.cuda()
-
-            # Apply CUDA graph optimization for repeated workloads of the same size
-            self.use_cuda_graphs = False  # Enable if generating fixed-size outputs repeatedly
-
-            # Track model sample rate
             self.sample_rate = self.model.sample_rate
 
             # Set default generation parameters
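This hunk drops all of the startup GPU tuning (TF32 flags, FP16 conversion, torch.compile, the explicit .cuda() move). If that path is ever wanted back, a minimal sketch reassembled from the deleted lines, assuming PyTorch 2.x on a CUDA device (`optimize_for_inference` is a hypothetical helper name; `model.lm` is the audiocraft LM component the old code compiled):

# Hypothetical helper, not part of this commit: re-applies the dropped optimizations.
import torch

def optimize_for_inference(model):
    model.eval()                    # inference mode, as the old code did
    model = model.half().cuda()     # FP16 weights on the GPU
    if hasattr(torch, "compile"):   # torch.compile exists on PyTorch 2.x only
        try:
            model.lm = torch.compile(model.lm, mode="reduce-overhead")
        except Exception as err:
            print(f"Warning: torch.compile optimization failed: {err}")
    return model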
@@ -62,57 +19,13 @@ class EndpointHandler:
                 top_k=250,
                 duration=30  # Default segment length
             )
-
-            # Create a batch processing queue for multiple requests
-            self.batch_size = 1  # Can be increased for batch processing
-            self.request_lock = Lock()  # Lock for thread safety
-
-            # Get GPU memory after optimization
-            torch.cuda.synchronize()  # Ensure GPU operations are complete
-            gpu_mem_after = torch.cuda.memory_allocated() / (1024 ** 3)  # in GB
-            print(f"GPU memory used after optimization: {gpu_mem_after:.2f} GB")
-            print(f"Additional memory used: {gpu_mem_after - gpu_mem_before:.2f} GB")
-
-            # Warm up the model with a dummy forward pass
-            self._warmup_model()
-
         except Exception as e:
             # Keep critical error logging only
             print(f"CRITICAL: Failed to initialize model: {e}")
             raise
 
-    def _warmup_model(self):
-        """Perform a warm-up inference to initialize CUDA kernels"""
-        try:
-            print("Warming up model with dummy inference...")
-            start_time = time.time()
-
-            # Create a simple prompt for warm-up
-            dummy_prompt = ["warm up"]
-
-            # Set minimal duration for warm-up
-            self.model.set_generation_params(duration=1)
-
-            # Run inference with dummy input
-            with torch.cuda.amp.autocast(enabled=self.fp16_mode):
-                with torch.no_grad():
-                    _ = self.model.generate(dummy_prompt, progress=False)
-
-            # Synchronize GPU to ensure completion
-            torch.cuda.synchronize()
-
-            # Clear GPU cache after warm-up
-            torch.cuda.empty_cache()
-
-            end_time = time.time()
-            print(f"Model warm-up completed in {end_time - start_time:.2f} seconds")
-        except Exception as e:
-            print(f"Warning: Model warm-up failed: {e}")
-
     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
-        """Handle prediction requests with optimized GPU processing"""
-        start_time = time.time()
-
+        """Handle prediction requests"""
         try:
             # Parse request data
             inputs = data.get("inputs", {})
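`__call__` still reads everything from `data["inputs"]`. A minimal client payload sketch follows; the keys inside `inputs` (`prompt`, `duration`, `parameters`) are inferred from the handler's variable names in the unchanged parsing code, so treat them as assumptions:

# Hypothetical request body for this handler.
payload = {
    "inputs": {
        "prompt": "lo-fi hip hop with warm piano",  # required; empty prompt returns an error
        "duration": 60,                             # seconds; > 30 triggers the continuation loop
        "parameters": {"temperature": 1.0, "top_k": 250, "cfg_coef": 3.0},
    }
}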
@@ -138,18 +51,8 @@ class EndpointHandler:
             if not prompt:
                 return {"error": "No prompt provided"}
 
-            # Generate music with optimized GPU processing
-            with self.request_lock:  # Ensure thread safety
-                audio_data = self.generate_music(prompt, duration, parameters)
-
-            # Log processing time for monitoring performance
-            end_time = time.time()
-            processing_time = end_time - start_time
-            print(f"Request processed in {processing_time:.2f} seconds (duration: {duration:.1f}s)")
-
-            # Log GPU memory usage for monitoring
-            gpu_mem = torch.cuda.memory_allocated() / (1024 ** 3)  # in GB
-            print(f"Current GPU memory usage: {gpu_mem:.2f} GB")
+            # Generate music
+            audio_data = self.generate_music(prompt, duration, parameters)
 
             return {
                 "generated_audio": audio_data.tolist(),
@@ -161,17 +64,12 @@
 
         except Exception as e:
             print(f"ERROR: Request processing failed: {e}")
-            end_time = time.time()
-            print(f"Failed request took {end_time - start_time:.2f} seconds")
             return {"error": str(e)}
 
     def generate_music(self, prompt: str, duration: float, parameters: Dict) -> np.ndarray:
-        """
-        Generate music with proper continuation for long sequences
-        Optimized for maximum GPU utilization
-        """
+        """Generate music with proper continuation for long sequences"""
         try:
-            # Generation parameters with performance optimizations
+            # Generation parameters
             segment_duration = min(30, duration)  # Max segment length (30s)
             overlap = 5  # Overlap between segments in seconds
@@ -185,10 +83,6 @@
                 "cfg_coef": parameters.get("cfg_coef", 3.0)
             }
 
-            # Additional parameters for GPU optimization
-            guidance_scale = parameters.get("guidance_scale", generation_params["cfg_coef"])
-            generation_params["cfg_coef"] = guidance_scale  # Support both parameter names
-
             # Set generation parameters
             self.model.set_generation_params(**generation_params)
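With the alias gone, a request that only sends `guidance_scale` now silently falls back to the `cfg_coef` default of 3.0. If older clients must keep working, the deleted lines could simply be restored:

# Backward-compat shim, as it appeared before this commit:
guidance_scale = parameters.get("guidance_scale", generation_params["cfg_coef"])
generation_params["cfg_coef"] = guidance_scale  # Support both parameter names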
@@ -196,91 +90,68 @@ class EndpointHandler:
             if isinstance(prompt, str):
                 prompt = [prompt]
 
-            # Use torch.no_grad and autocast for optimized inference
-            with torch.no_grad():
-                with torch.cuda.amp.autocast(enabled=self.fp16_mode):
-                    # Generate first segment with timing
-                    segment_start = time.time()
-                    segment = self.model.generate(prompt, progress=False)
-                    torch.cuda.synchronize()  # Ensure generation is complete
-                    segment_end = time.time()
-
-                    print(f"First segment ({segment_duration}s) generated in {segment_end - segment_start:.2f} seconds")
-
-                    # If duration is less than or equal to segment_duration, we're done
-                    if duration <= segment_duration:
-                        # Trim to exact requested duration if needed
-                        max_samples = int(duration * self.sample_rate)
-                        if segment.shape[2] > max_samples:
-                            segment = segment[:, :, :max_samples]
-                        audio_data = segment.detach().cpu().float()[0].numpy()
-                        return audio_data
-
-                    # Track remaining duration for multi-segment generation
-                    remaining_duration = duration - segment_duration + overlap
-                    segment_count = 1
-
-                    print(f"Multi-segment generation needed. Total segments: ~{1 + int(remaining_duration / (segment_duration - overlap))}")
-
-                    # Continue generating segments until we reach desired duration
-                    while remaining_duration > 0:
-                        # Adjust segment duration for last segment if needed
-                        if remaining_duration < segment_duration - overlap:
-                            current_segment_duration = remaining_duration + overlap
-                            self.model.set_generation_params(
-                                use_sampling=generation_params["use_sampling"],
-                                top_k=generation_params["top_k"],
-                                top_p=generation_params["top_p"],
-                                temperature=generation_params["temperature"],
-                                duration=current_segment_duration,
-                                cfg_coef=generation_params["cfg_coef"]
-                            )
-
-                        # Extract last few seconds of current segment for continuation
-                        last_seconds = segment[:, :, -overlap*self.sample_rate:]
-
-                        # Generate continuation with timing
-                        cont_start = time.time()
-                        next_segment = self.model.generate_continuation(
-                            last_seconds,
-                            self.sample_rate,
-                            prompt,
-                            progress=False
-                        )
-                        torch.cuda.synchronize()  # Ensure generation is complete
-                        cont_end = time.time()
-
-                        # Join segments (removing overlap from first segment)
-                        segment = torch.cat([segment[:, :, :-overlap*self.sample_rate], next_segment], 2)
-
-                        # Update remaining duration
-                        if remaining_duration < segment_duration - overlap:
-                            seg_duration = remaining_duration
-                            remaining_duration = 0
-                        else:
-                            seg_duration = segment_duration - overlap
-                            remaining_duration -= seg_duration
-
-                        segment_count += 1
-                        print(f"Segment {segment_count} ({seg_duration:.1f}s) generated in {cont_end - cont_start:.2f} seconds")
-
-            # Trim to exact requested duration if needed
-            max_samples = int(duration * self.sample_rate)
-            if segment.shape[2] > max_samples:
-                segment = segment[:, :, :max_samples]
-
-            # Convert to numpy array
-            audio_data = segment.detach().cpu().float()[0].numpy()
-
-            # Optional memory cleanup after large generations
-            if duration > 60:
-                # Clear GPU cache after large generations
-                torch.cuda.empty_cache()
-
-            return audio_data
+            # Generate first segment
+            segment = self.model.generate(prompt, progress=False)  # Disabled progress tracking
+
+            # If duration is less than or equal to segment_duration, we're done
+            if duration <= segment_duration:
+                # Trim to exact requested duration if needed
+                max_samples = int(duration * self.sample_rate)
+                if segment.shape[2] > max_samples:
+                    segment = segment[:, :, :max_samples]
+                audio_data = segment.detach().cpu().float()[0].numpy()
+                return audio_data
+
+            # Track remaining duration for multi-segment generation
+            remaining_duration = duration - segment_duration + overlap
+            segment_count = 1
+
+            # Continue generating segments until we reach desired duration
+            while remaining_duration > 0:
+                # Adjust segment duration for last segment if needed
+                if remaining_duration < segment_duration - overlap:
+                    current_segment_duration = remaining_duration + overlap
+                    self.model.set_generation_params(
+                        use_sampling=generation_params["use_sampling"],
+                        top_k=generation_params["top_k"],
+                        top_p=generation_params["top_p"],
+                        temperature=generation_params["temperature"],
+                        duration=current_segment_duration,
+                        cfg_coef=generation_params["cfg_coef"]
+                    )
+
+                # Extract last few seconds of current segment for continuation
+                last_seconds = segment[:, :, -overlap*self.sample_rate:]
+
+                # Generate continuation
+                next_segment = self.model.generate_continuation(
+                    last_seconds,
+                    self.sample_rate,
+                    prompt,
+                    progress=False  # Disabled progress tracking
+                )
+
+                # Join segments (removing overlap from first segment)
+                segment = torch.cat([segment[:, :, :-overlap*self.sample_rate], next_segment], 2)
+
+                # Update remaining duration
+                if remaining_duration < segment_duration - overlap:
+                    remaining_duration = 0
+                else:
+                    remaining_duration -= (segment_duration - overlap)
+
+                segment_count += 1
+
+            # Trim to exact requested duration if needed
+            max_samples = int(duration * self.sample_rate)
+            if segment.shape[2] > max_samples:
+                segment = segment[:, :, :max_samples]
+
+            # Convert to numpy array
+            audio_data = segment.detach().cpu().float()[0].numpy()
+
+            return audio_data
 
         except Exception as e:
             print(f"ERROR: Music generation failed: {e}")
-            # Clean up GPU memory on error
-            torch.cuda.empty_cache()
             raise
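The continuation bookkeeping above is easiest to verify with a standalone trace of the same arithmetic, a sketch with no model calls whose constants match the handler's 30 s segments and 5 s overlap (`segment_schedule` is a hypothetical name, not part of the commit):

# Standalone sketch of the handler's segment schedule (illustration only).
def segment_schedule(duration: float, segment: float = 30.0, overlap: float = 5.0):
    lengths = [min(segment, duration)]           # first model.generate call
    if duration <= segment:
        return lengths
    remaining = duration - segment + overlap
    while remaining > 0:
        if remaining < segment - overlap:        # shorter final continuation
            lengths.append(remaining + overlap)  # duration passed to generate_continuation
            remaining = 0
        else:
            lengths.append(segment)              # full-length continuation
            remaining -= segment - overlap
    return lengths

print(segment_schedule(60))  # [30.0, 30.0, 15.0]

Each continuation re-generates the 5 s overlap, so the net audio for a 60 s request is 30 + 25 + 10 = 65 s, which the final trim cuts back to exactly 60 s.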
 