PyTorch · musicgen
Phoenixak99 committed
Commit 90cb56c · verified · 1 Parent(s): c7fca87

Update handler.py

Files changed (1)
  1. handler.py +149 -166
handler.py CHANGED
@@ -75,11 +75,11 @@ class EndpointHandler:
         logger.info(f"Requested duration: {duration} seconds")
 
         # Generate audio
-        if duration <= self.max_segment_duration:  # For short durations, generate in one go
+        if duration <= self.max_segment_duration - 5:  # For short durations, generate in one go
             audio_output = self._generate_short_audio(prompt, duration, parameters)
         else:
-            # Use sliding window approach for longer durations
-            audio_output = self._generate_long_audio_sliding_window(prompt, duration, parameters)
+            # Use basic segmentation for longer durations
+            audio_output = self._generate_long_audio(prompt, duration, parameters)
 
         # Monitor GPU memory after generation
         allocated = torch.cuda.memory_allocated() / 1e9
@@ -137,8 +137,7 @@ class EndpointHandler:
 
         # Generate audio
         logger.info(f"Generation parameters: {generation_kwargs}")
-        with torch.inference_mode():
-            outputs = self.model.generate(**inputs, **generation_kwargs)
+        outputs = self.model.generate(**inputs, **generation_kwargs)
 
         # Return audio
         return outputs[0].cpu().numpy()
@@ -157,21 +156,20 @@ class EndpointHandler:
             ).to("cuda")
 
             # Generate with minimal parameters
-            with torch.inference_mode():
-                outputs = self.model.generate(
-                    **inputs,
-                    max_new_tokens=max_new_tokens,
-                    do_sample=True,
-                    guidance_scale=1.0  # Minimal guidance
-                )
+            outputs = self.model.generate(
+                **inputs,
+                max_new_tokens=max_new_tokens,
+                do_sample=True,
+                guidance_scale=1.0  # Minimal guidance
+            )
 
             return outputs[0].cpu().numpy()
         except Exception as e2:
             logger.error(f"Second attempt failed: {e2}")
             raise e2
 
-    def _equal_power_crossfade(self, segment1, segment2, overlap_samples):
-        """Apply an equal-power crossfade between segments for smooth transitions."""
+    def _simple_crossfade(self, segment1, segment2, overlap_samples):
+        """Apply a simple linear crossfade between segments."""
         # Get the length of the segments
         length1 = segment1.shape[1]
         length2 = segment2.shape[1]
@@ -189,12 +187,11 @@ class EndpointHandler:
         # Copy the non-overlapping part of segment2
         result[:, length1:] = segment2[:, overlap_samples:]
 
-        # Apply equal-power crossfade to the overlapping parts
+        # Apply simple linear crossfade to the overlapping parts
         if overlap_samples > 0:
-            # Equal power crossfade curves (cosine/sine based for smoother transitions)
-            t = np.linspace(0, np.pi/2, overlap_samples)
-            fade_out = np.cos(t)**2
-            fade_in = np.sin(t)**2
+            # Linear fade factors
+            fade_out = np.linspace(1, 0, overlap_samples)
+            fade_in = np.linspace(0, 1, overlap_samples)
 
             # Get the overlapping parts
             segment1_end = segment1[:, -overlap_samples:].copy()
@@ -213,104 +210,80 @@ class EndpointHandler:
 
         return result
 
-    def _extract_style_keywords(self, prompt):
-        """Extract potential style-related keywords from the prompt to emphasize in continuations."""
-        # Common musical style keywords
-        style_keywords = [
-            "rock", "jazz", "classical", "pop", "electronic", "hip-hop", "rap", "country",
-            "folk", "blues", "metal", "ambient", "orchestral", "indie", "r&b", "soul",
-            "techno", "house", "drum and bass", "dubstep", "trance", "lo-fi", "lofi", "cinematic",
-            "soundtrack", "instrumental", "acoustic", "electric", "synth", "piano",
-            "guitar", "bass", "drums", "violin", "cello", "trumpet", "saxophone"
-        ]
-
-        # Extract any style keywords from the prompt
-        prompt_lower = prompt.lower()
-        found_keywords = []
-
-        for keyword in style_keywords:
-            if keyword in prompt_lower:
-                found_keywords.append(keyword)
-
-        # Return a string of found keywords or a default
-        if found_keywords:
-            return ", ".join(found_keywords)
-        else:
-            return "musical"
-
-    def _generate_long_audio_sliding_window(self, prompt, total_duration, params):
-        """
-        Generate long audio using Meta's sliding window approach:
-        - Generate 30-second chunks
-        - Slide window by 10 seconds
-        - Crossfade overlapping sections to maintain continuity
-        """
-        # Initialize variables
-        segment_duration = self.max_segment_duration  # 30 seconds per segment
-        slide_window = 10  # Slide by 10 seconds for each new segment
-        overlap_duration = segment_duration - slide_window  # 20 seconds of overlap
-
-        # Number of segments needed (rounding up)
-        num_segments = math.ceil((total_duration - overlap_duration) / slide_window) + 1
-
-        # Initialize audio array
-        final_audio = None
-
-        # Setup generation kwargs
-        generation_kwargs = {
-            "do_sample": True,
-            "guidance_scale": 3.0
-        }
-
-        # Add additional parameters if provided
-        if "top_k" in params:
-            generation_kwargs["top_k"] = min(int(params["top_k"]), 500)
-
-        if "temperature" in params:
-            temp = float(params["temperature"])
-            if temp > 0.1:
-                generation_kwargs["temperature"] = min(temp, 1.5)
-            else:
-                # A slightly lower temperature helps with style consistency
-                generation_kwargs["temperature"] = 0.95
-
-        if "guidance_scale" in params:
-            generation_kwargs["guidance_scale"] = min(float(params["guidance_scale"]), 3.0)
-        elif "cfg_coef" in params:
-            generation_kwargs["guidance_scale"] = min(float(params["cfg_coef"]), 3.0)
-
-        logger.info(f"Long audio generation using sliding window approach, {num_segments} segments")
-        logger.info(f"Generation parameters: {generation_kwargs}")
-
-        # Extract style keywords for better continuity
-        style_keywords = self._extract_style_keywords(prompt)
-
-        for i in range(num_segments):
+    def _advanced_crossfade(self, segment1, segment2, overlap_samples):
+        """Apply an advanced equal-power crossfade between segments."""
+        # Get the length of the segments
+        length1 = segment1.shape[1]
+        length2 = segment2.shape[1]
+
+        # Ensure we have enough samples for crossfading
+        overlap_samples = min(overlap_samples, length1, length2)
+
+        # Create the result array (total length minus overlap)
+        result_length = length1 + length2 - overlap_samples
+        result = np.zeros((segment1.shape[0], result_length), dtype=segment1.dtype)
+
+        # Copy the non-overlapping part of segment1
+        result[:, :length1-overlap_samples] = segment1[:, :length1-overlap_samples]
+
+        # Copy the non-overlapping part of segment2
+        result[:, length1:] = segment2[:, overlap_samples:]
+
+        # Apply equal-power crossfade to the overlapping parts
+        if overlap_samples > 0:
+            # Equal power crossfade curves (cosine/sine based)
+            t = np.linspace(0, np.pi/2, overlap_samples)
+            fade_out = np.cos(t)**2
+            fade_in = np.sin(t)**2
+
+            # Get the overlapping parts
+            segment1_end = segment1[:, -overlap_samples:].copy()
+            segment2_start = segment2[:, :overlap_samples].copy()
+
+            # Apply the fades
+            for ch in range(segment1_end.shape[0]):
+                segment1_end[ch] *= fade_out
+                segment2_start[ch] *= fade_in
+
+            # Combine the faded parts
+            crossfaded = segment1_end + segment2_start
+
+            # Add to the result
+            result[:, length1-overlap_samples:length1] = crossfaded
+
+        return result
+
+    def _generate_long_audio(self, prompt, total_duration, params):
+        """Generate long audio with improved segment continuity."""
+        # Overlap duration for crossfade
+        overlap_duration = 5  # Using a longer overlap for better transitions
+
+        # Initialize variables
+        remaining_duration = total_duration
+        final_audio = None
+        segment_idx = 0
+
+        # Calculate number of segments needed
+        segment_duration = self.max_segment_duration
+        overlap_samples = int(overlap_duration * self.sampling_rate)
+
+        # Process in segments
+        while remaining_duration > 0:
             # Calculate segment duration
-            if i == 0:
-                # First segment is always the full segment duration
-                current_segment_duration = segment_duration
-            else:
-                # Calculate remaining duration (accounting for overlap)
-                remaining_duration = total_duration - (i * slide_window)
-                if remaining_duration <= 0:
-                    break
-
-                # Last segment might be shorter
-                current_segment_duration = min(segment_duration, remaining_duration + overlap_duration)
+            target_duration = min(segment_duration, remaining_duration + (segment_idx > 0) * overlap_duration)
+
+            logger.info(f"Generating segment {segment_idx+1}, duration: {target_duration:.1f}s")
 
             try:
-                # Create enhanced prompt for better continuity
-                if i == 0:
+                # The main change: We directly use continuation prompts without trying prompt_audio
+                if segment_idx == 0:
+                    # First segment with basic prompt
                     segment_prompt = prompt
                 else:
-                    # Add continuation instructions with style keywords
-                    segment_prompt = f"{prompt} [continuation keeping same {style_keywords} style]"
-
-                logger.info(f"Generating segment {i+1}/{num_segments}, duration: {current_segment_duration:.1f}s")
-                logger.info(f"Segment prompt: {segment_prompt}")
+                    # Subsequent segments with enhanced continuation prompt
+                    segment_prompt = f"{prompt} [continuing segment {segment_idx+1}, seamless continuation]"
 
-                # Process text input
+                # Process text for this segment
                 inputs = self.processor(
                     text=[segment_prompt],
                     padding=True,
@@ -318,82 +291,92 @@ class EndpointHandler:
                 ).to("cuda")
 
                 # Calculate max_new_tokens from duration
-                max_new_tokens = int(current_segment_duration * 50)
-                generation_kwargs["max_new_tokens"] = max_new_tokens
-
-                # Generate segment
-                with torch.inference_mode():
-                    outputs = self.model.generate(**inputs, **generation_kwargs)
-
-                # Get the audio data
-                segment_audio = outputs[0].cpu().numpy()
-
-                # If this is the first segment, just keep it
-                if final_audio is None:
-                    final_audio = segment_audio
+                max_new_tokens = int(target_duration * 50)
+
+                # Generation parameters for transformers implementation
+                generation_kwargs = {
+                    "max_new_tokens": max_new_tokens,
+                    "do_sample": True,
+                    "guidance_scale": 3.0
+                }
+
+                # Add additional parameters if provided
+                if "top_k" in params:
+                    generation_kwargs["top_k"] = min(int(params["top_k"]), 500)
+
+                if "temperature" in params:
+                    temp = float(params["temperature"])
+                    if temp > 0.1:
+                        generation_kwargs["temperature"] = min(temp, 1.5)
+
+                if "guidance_scale" in params:
+                    generation_kwargs["guidance_scale"] = min(float(params["guidance_scale"]), 3.0)
+                elif "cfg_coef" in params:
+                    generation_kwargs["guidance_scale"] = min(float(params["cfg_coef"]), 3.0)
+
+                # Generate this segment
+                outputs = self.model.generate(**inputs, **generation_kwargs)
+                segment_output = outputs[0].cpu().numpy()
+
+                # Add this segment to our final output
+                if segment_idx == 0:
+                    final_audio = segment_output
                 else:
-                    # For subsequent segments, we need to crossfade
-                    overlap_samples = int(overlap_duration * self.sampling_rate)
-                    crossfade_samples = int(3.0 * self.sampling_rate)  # 3-second crossfade
-
-                    # Ensure the segment is long enough for crossfading
-                    if segment_audio.shape[1] < overlap_samples:
-                        logger.warning(f"Segment {i+1} too short for proper crossfade, using concatenation")
-                        final_audio = np.concatenate([final_audio, segment_audio], axis=1)
-                    else:
-                        # Calculate where to crossfade
-                        current_length = final_audio.shape[1]
-                        segment_offset = current_length - overlap_samples
-
-                        # Create a new combined audio array with room for the new segment
-                        new_length = segment_offset + segment_audio.shape[1]
-                        combined_audio = np.zeros((final_audio.shape[0], new_length), dtype=final_audio.dtype)
-
-                        # Copy the existing audio
-                        combined_audio[:, :segment_offset] = final_audio[:, :segment_offset]
-
-                        # Crossfade the overlapping region
-                        crossfade_region = min(crossfade_samples, overlap_samples)
-
-                        # Calculate crossfade weights (equal power)
-                        t = np.linspace(0, np.pi/2, crossfade_region)
-                        fade_out = np.cos(t)**2
-                        fade_in = np.sin(t)**2
-
-                        # Apply crossfade at the transition point
-                        for ch in range(final_audio.shape[0]):
-                            # Crossfade
-                            combined_audio[ch, segment_offset:segment_offset+crossfade_region] = (
-                                final_audio[ch, segment_offset:segment_offset+crossfade_region] * fade_out +
-                                segment_audio[ch, :crossfade_region] * fade_in
-                            )
-
-                            # Copy the rest of the new segment (after crossfade)
-                            combined_audio[ch, segment_offset+crossfade_region:] = segment_audio[ch, crossfade_region:]
-
-                        final_audio = combined_audio
+                    # Apply advanced crossfade for better transitions
+                    final_audio = self._advanced_crossfade(final_audio, segment_output, overlap_samples)
+
+                # Update remaining duration
+                if segment_idx == 0:
+                    remaining_duration -= target_duration
+                else:
+                    remaining_duration -= (target_duration - overlap_duration)
 
                 # Clear CUDA cache
                 torch.cuda.empty_cache()
 
+                # Log progress
+                logger.info(f"GPU memory usage: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
+                logger.info(f"Remaining duration: {remaining_duration:.1f}s")
+
             except Exception as e:
-                logger.error(f"Error generating segment {i+1}: {e}")
-                # If we have some output, return it
+                logger.error(f"Error generating segment {segment_idx+1}: {e}")
                 if final_audio is not None:
-                    break
-                # Otherwise, try a simpler approach for at least some output
-                return self._generate_short_audio(prompt, min(segment_duration, total_duration), params)
+                    logger.info("Returning partial audio after error")
+                    return final_audio
+
+                # Try again with minimal parameters
+                try:
+                    logger.info("Trying minimal generation parameters")
+                    inputs = self.processor(
+                        text=[prompt],
+                        padding=True,
+                        return_tensors="pt",
+                    ).to("cuda")
+
+                    outputs = self.model.generate(
+                        **inputs,
+                        max_new_tokens=int(min(target_duration, 15.0) * 50),
+                        do_sample=True
+                    )
+
+                    return outputs[0].cpu().numpy()
+                except Exception as e2:
+                    logger.error(f"Minimal generation also failed: {e2}")
+                    raise e2
+
+            # Move to next segment
+            segment_idx += 1
+
+            # Break if we've generated enough audio
+            if remaining_duration <= 0:
+                break
 
-        # Apply a smooth fade-out at the end
-        if final_audio.shape[1] > self.sampling_rate:
-            fade_samples = min(int(1.0 * self.sampling_rate), final_audio.shape[1] // 10)  # 1-second fade out
-            fade_out = np.linspace(1.0, 0.0, fade_samples)**0.5  # Smooth curve
+        # Apply a smooth fade out to the last 0.5 seconds
+        if final_audio.shape[1] > self.sampling_rate // 2:
+            fade_samples = self.sampling_rate // 2  # 0.5 seconds
+            fade_out = np.linspace(1.0, 0.0, fade_samples)**0.7  # Smooth curve
             for ch in range(final_audio.shape[0]):
                 final_audio[ch, -fade_samples:] *= fade_out
 
-        # Trim to requested duration if needed
-        max_samples = int(total_duration * self.sampling_rate)
-        if final_audio.shape[1] > max_samples:
-            final_audio = final_audio[:, :max_samples]
-
+        # Return the final audio
         return final_audio
 
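Note on the crossfade math: in _advanced_crossfade the fade curves are cos²(t) and sin²(t), and since cos²(t) + sin²(t) = 1 the two weights sum to exactly 1 at every sample, so the overlap never clips and a correlated signal passes through at constant amplitude. Strictly speaking that makes it an equal-gain fade; a textbook equal-power fade would use cos/sin so that the squared weights sum to 1. A minimal standalone sketch of the same blend, with hypothetical demo signals (the function name and the 32 kHz rate are illustrative, not from the commit):

import numpy as np

def equal_power_crossfade(segment1, segment2, overlap_samples):
    """Blend the tail of segment1 into the head of segment2; arrays are (channels, samples)."""
    overlap_samples = min(overlap_samples, segment1.shape[1], segment2.shape[1])
    t = np.linspace(0, np.pi / 2, overlap_samples)
    fade_out = np.cos(t) ** 2  # 1 -> 0 across the overlap
    fade_in = np.sin(t) ** 2   # 0 -> 1 across the overlap; fade_out + fade_in == 1
    blended = segment1[:, -overlap_samples:] * fade_out + segment2[:, :overlap_samples] * fade_in
    return np.concatenate(
        [segment1[:, :-overlap_samples], blended, segment2[:, overlap_samples:]], axis=1
    )

# Hypothetical demo: two one-channel 220 Hz sine segments at 32 kHz, 0.5 s overlap
sr = 32000
a = np.sin(2 * np.pi * 220 * np.arange(sr) / sr)[None, :]
b = np.sin(2 * np.pi * 220 * np.arange(sr) / sr)[None, :]
out = equal_power_crossfade(a, b, sr // 2)
assert out.shape == (1, 2 * sr - sr // 2)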
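Note on the token and duration budget: max_new_tokens = int(duration * 50) matches MusicGen's roughly 50 audio tokens per second, and every segment after the first re-generates overlap_duration seconds that the crossfade then consumes, which is why the loop subtracts target_duration - overlap_duration from the remaining time. A small sketch of that bookkeeping (plan_segments is illustrative; max_segment_duration = 30 is assumed from the original comments):

def plan_segments(total_duration, segment_duration=30, overlap_duration=5):
    """Mirror the _generate_long_audio loop: return (target_duration, max_new_tokens) per segment."""
    plan, remaining, idx = [], total_duration, 0
    while remaining > 0:
        target = min(segment_duration, remaining + (idx > 0) * overlap_duration)
        plan.append((target, int(target * 50)))  # ~50 MusicGen tokens per generated second
        remaining -= target if idx == 0 else target - overlap_duration
        idx += 1
    return plan

# 70 s request -> 30 s, then 30 s (25 s new after the 5 s crossfade), then 20 s (15 s new)
print(plan_segments(70))  # [(30, 1500), (30, 1500), (20, 1000)]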