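"""Inference Endpoints handler for MusicGen text-to-music generation.

Loads a MusicGen checkpoint with transformers, serves short requests in a
single generate() call, and stitches longer requests together from
overlapping segments joined with an equal-power crossfade.
"""
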
import logging
from typing import Any, Dict

import numpy as np
import torch
from transformers import AutoProcessor, MusicgenForConditionalGeneration

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class EndpointHandler:
    def __init__(self, path=""):
        """Initialize the endpoint handler with optimized model loading."""
        # Load the processor
        logger.info("Initializing processor")
        self.processor = AutoProcessor.from_pretrained(path)

        # Load the model in standard precision
        logger.info("Loading model in standard precision")
        self.model = MusicgenForConditionalGeneration.from_pretrained(
            path,
            torch_dtype=torch.float32,  # full precision for better numerical stability
        )

        # Move the model to CUDA after loading (the endpoint assumes a GPU)
        logger.info("Moving model to CUDA")
        self.model = self.model.to("cuda")

        # Read the audio configuration from the model
        self.sampling_rate = self.model.config.audio_encoder.sampling_rate
        self.max_segment_duration = 30  # maximum duration per segment, in seconds

        # Log GPU memory info
        allocated = torch.cuda.memory_allocated() / 1e9
        reserved = torch.cuda.memory_reserved() / 1e9
        logger.info(f"Initial GPU memory allocated: {allocated:.2f} GB")
        logger.info(f"GPU memory reserved: {reserved:.2f} GB")

    def __call__(self, data: Dict[str, Any]) -> Any:
        """
        Process the incoming request data and generate audio.

        Args:
            data (dict): The payload with the text prompt and generation parameters.
        """
        try:
            # Extract inputs and parameters from the payload
            inputs = data.get("inputs", data)
            parameters = data.get("parameters", {})

            # Handle inputs
            if isinstance(inputs, str):
                prompt = inputs
                duration = parameters.get("duration", 10)
            elif isinstance(inputs, dict):
                prompt = inputs.get("text") or inputs.get("prompt")
                duration = inputs.get("duration", parameters.get("duration", 10))
            else:
                prompt = None
                duration = parameters.get("duration", 10)

            # A duration in parameters overrides the one in inputs; pop it so
            # it is not forwarded to generate()
            if "duration" in parameters:
                duration = parameters.pop("duration")

            # Cap duration to prevent excessive resource usage
            duration = min(float(duration), 300)

            # Validate the prompt
            if not prompt:
                return {"error": "No prompt provided."}

            logger.info(f"Received prompt: {prompt}")
            logger.info(f"Requested duration: {duration} seconds")

            # Generate audio: short requests in one pass, longer ones in segments
            if duration <= self.max_segment_duration - 5:
                audio_output = self._generate_short_audio(prompt, duration, parameters)
            else:
                audio_output = self._generate_long_audio(prompt, duration, parameters)

            # Monitor GPU memory after generation
            allocated = torch.cuda.memory_allocated() / 1e9
            logger.info(f"Post-generation GPU memory: {allocated:.2f} GB")

            return [
                {
                    "generated_audio": audio_output.tolist(),
                    "sample_rate": self.sampling_rate,
                }
            ]
        except Exception as e:
            logger.error(f"Exception during generation: {e}")
            import traceback

            logger.error(traceback.format_exc())
            return {"error": str(e)}

    def _generate_short_audio(self, prompt, duration, params):
        """Generate a single audio segment using the transformers API."""
        logger.info(f"Generating short audio segment: {duration}s")
        try:
            # Process text input
            inputs = self.processor(
                text=[prompt],
                padding=True,
                return_tensors="pt",
            ).to("cuda")

            # Calculate max_new_tokens from duration:
            # MusicGen produces roughly 50 codec tokens per second of 32 kHz audio
            max_new_tokens = int(duration * 50)

            # Generation parameters for the transformers implementation
            generation_kwargs = {
                "max_new_tokens": max_new_tokens,
                "do_sample": True,
                "guidance_scale": 3.0,  # default classifier-free guidance scale
            }

            # Add additional parameters if provided
            if "top_k" in params:
                generation_kwargs["top_k"] = min(int(params["top_k"]), 500)
            if "temperature" in params:
                temp = float(params["temperature"])
                if temp > 0.1:  # avoid near-zero temperature
                    generation_kwargs["temperature"] = min(temp, 1.5)
            if "guidance_scale" in params:
                generation_kwargs["guidance_scale"] = min(float(params["guidance_scale"]), 3.0)
            elif "cfg_coef" in params:
                generation_kwargs["guidance_scale"] = min(float(params["cfg_coef"]), 3.0)

            # Generate audio
            logger.info(f"Generation parameters: {generation_kwargs}")
            outputs = self.model.generate(**inputs, **generation_kwargs)

            # Return the first (and only) batch item as a NumPy array
            return outputs[0].cpu().numpy()
        except Exception as e:
            logger.error(f"Error during generation: {e}")
            # Try again with minimal parameters
            try:
                logger.info("Trying generation with minimal parameters")
                # Process text using the processor
                inputs = self.processor(
                    text=[prompt],
                    padding=True,
                    return_tensors="pt",
                ).to("cuda")

                # Recompute max_new_tokens here: if the first attempt failed
                # before computing it, referencing it would raise a NameError
                max_new_tokens = int(duration * 50)

                # Generate with minimal parameters
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    do_sample=True,
                    guidance_scale=1.0,  # minimal guidance
                )
                return outputs[0].cpu().numpy()
            except Exception as e2:
                logger.error(f"Second attempt failed: {e2}")
                raise e2

    def _simple_crossfade(self, segment1, segment2, overlap_samples):
        """Apply a simple linear crossfade between segments.

        Kept as a lighter alternative; _generate_long_audio currently uses
        _advanced_crossfade instead.
        """
        # Get the length of the segments
        length1 = segment1.shape[1]
        length2 = segment2.shape[1]

        # Ensure we have enough samples for crossfading
        overlap_samples = min(overlap_samples, length1, length2)

        # Create the result array (total length minus the overlap)
        result_length = length1 + length2 - overlap_samples
        result = np.zeros((segment1.shape[0], result_length), dtype=segment1.dtype)

        # Copy the non-overlapping part of segment1
        result[:, :length1 - overlap_samples] = segment1[:, :length1 - overlap_samples]

        # Copy the non-overlapping part of segment2
        result[:, length1:] = segment2[:, overlap_samples:]

        # Apply a linear crossfade to the overlapping parts
        if overlap_samples > 0:
            # Linear fade factors
            fade_out = np.linspace(1, 0, overlap_samples)
            fade_in = np.linspace(0, 1, overlap_samples)

            # Get the overlapping parts
            segment1_end = segment1[:, -overlap_samples:].copy()
            segment2_start = segment2[:, :overlap_samples].copy()

            # Apply the fades channel by channel
            for ch in range(segment1_end.shape[0]):
                segment1_end[ch] *= fade_out
                segment2_start[ch] *= fade_in

            # Combine the faded parts
            crossfaded = segment1_end + segment2_start

            # Place the blend in the overlap region
            result[:, length1 - overlap_samples:length1] = crossfaded

        return result

    def _advanced_crossfade(self, segment1, segment2, overlap_samples):
        """Apply an equal-power crossfade between segments."""
        # Get the length of the segments
        length1 = segment1.shape[1]
        length2 = segment2.shape[1]

        # Ensure we have enough samples for crossfading
        overlap_samples = min(overlap_samples, length1, length2)

        # Create the result array (total length minus the overlap)
        result_length = length1 + length2 - overlap_samples
        result = np.zeros((segment1.shape[0], result_length), dtype=segment1.dtype)

        # Copy the non-overlapping part of segment1
        result[:, :length1 - overlap_samples] = segment1[:, :length1 - overlap_samples]

        # Copy the non-overlapping part of segment2
        result[:, length1:] = segment2[:, overlap_samples:]

        # Apply an equal-power crossfade to the overlapping parts
        if overlap_samples > 0:
            # Equal-power gains: cos(t) and sin(t) satisfy cos^2 + sin^2 = 1,
            # so the summed power stays constant across the overlap. (cos^2 and
            # sin^2 gains would be equal-gain, not equal-power, and dip in
            # loudness mid-fade for uncorrelated material like these segments.)
            t = np.linspace(0, np.pi / 2, overlap_samples)
            fade_out = np.cos(t)
            fade_in = np.sin(t)

            # Get the overlapping parts
            segment1_end = segment1[:, -overlap_samples:].copy()
            segment2_start = segment2[:, :overlap_samples].copy()

            # Apply the fades channel by channel
            for ch in range(segment1_end.shape[0]):
                segment1_end[ch] *= fade_out
                segment2_start[ch] *= fade_in

            # Combine the faded parts
            crossfaded = segment1_end + segment2_start

            # Place the blend in the overlap region
            result[:, length1 - overlap_samples:length1] = crossfaded

        return result
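
    # Overlap arithmetic, worked through for the values used in this handler
    # (assuming the 32 kHz rate typical of MusicGen checkpoints): two 30 s
    # segments with a 5 s overlap give
    #   30 * 32000 + 30 * 32000 - 5 * 32000 = 1,760,000 samples = 55 s,
    # so each segment after the first contributes
    # max_segment_duration - overlap_duration seconds of new audio.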

    def _generate_long_audio(self, prompt, total_duration, params):
        """Generate long audio with improved segment continuity."""
        # Overlap between consecutive segments for the crossfade
        overlap_duration = 5  # seconds; a longer overlap gives smoother transitions

        # Initialize loop state
        remaining_duration = total_duration
        final_audio = None
        segment_idx = 0

        # Segment sizing
        segment_duration = self.max_segment_duration
        overlap_samples = int(overlap_duration * self.sampling_rate)

        # Process in segments
        while remaining_duration > 0:
            # Segments after the first regenerate the overlap region, so they
            # target remaining_duration plus the overlap
            extra = overlap_duration if segment_idx > 0 else 0
            target_duration = min(segment_duration, remaining_duration + extra)
            logger.info(f"Generating segment {segment_idx + 1}, duration: {target_duration:.1f}s")
            try:
                # Use continuation-flavoured text prompts rather than audio
                # prompting for segments after the first
                if segment_idx == 0:
                    # First segment with the basic prompt
                    segment_prompt = prompt
                else:
                    # Subsequent segments with an enhanced continuation prompt
                    segment_prompt = f"{prompt} [continuing segment {segment_idx + 1}, seamless continuation]"

                # Process text for this segment
                inputs = self.processor(
                    text=[segment_prompt],
                    padding=True,
                    return_tensors="pt",
                ).to("cuda")

                # Calculate max_new_tokens from duration (~50 tokens per second)
                max_new_tokens = int(target_duration * 50)

                # Generation parameters for the transformers implementation
                generation_kwargs = {
                    "max_new_tokens": max_new_tokens,
                    "do_sample": True,
                    "guidance_scale": 3.0,
                }

                # Add additional parameters if provided
                if "top_k" in params:
                    generation_kwargs["top_k"] = min(int(params["top_k"]), 500)
                if "temperature" in params:
                    temp = float(params["temperature"])
                    if temp > 0.1:
                        generation_kwargs["temperature"] = min(temp, 1.5)
                if "guidance_scale" in params:
                    generation_kwargs["guidance_scale"] = min(float(params["guidance_scale"]), 3.0)
                elif "cfg_coef" in params:
                    generation_kwargs["guidance_scale"] = min(float(params["cfg_coef"]), 3.0)

                # Generate this segment
                outputs = self.model.generate(**inputs, **generation_kwargs)
                segment_output = outputs[0].cpu().numpy()

                # Add this segment to the final output
                if segment_idx == 0:
                    final_audio = segment_output
                else:
                    # Crossfade for a smoother transition
                    final_audio = self._advanced_crossfade(final_audio, segment_output, overlap_samples)

                # Update the remaining duration; after the first segment only
                # the non-overlapping portion counts as new audio
                if segment_idx == 0:
                    remaining_duration -= target_duration
                else:
                    remaining_duration -= (target_duration - overlap_duration)

                # Clear the CUDA cache between segments
                torch.cuda.empty_cache()

                # Log progress
                logger.info(f"GPU memory usage: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
                logger.info(f"Remaining duration: {remaining_duration:.1f}s")
            except Exception as e:
                logger.error(f"Error generating segment {segment_idx + 1}: {e}")
                if final_audio is not None:
                    logger.info("Returning partial audio after error")
                    return final_audio

                # Try again with minimal parameters
                try:
                    logger.info("Trying minimal generation parameters")
                    inputs = self.processor(
                        text=[prompt],
                        padding=True,
                        return_tensors="pt",
                    ).to("cuda")
                    outputs = self.model.generate(
                        **inputs,
                        max_new_tokens=int(min(target_duration, 15.0) * 50),
                        do_sample=True,
                    )
                    return outputs[0].cpu().numpy()
                except Exception as e2:
                    logger.error(f"Minimal generation also failed: {e2}")
                    raise e2

            # Move on to the next segment; the loop condition handles termination
            segment_idx += 1

        # Apply a smooth fade-out to the last 0.5 seconds
        if final_audio.shape[1] > self.sampling_rate // 2:
            fade_samples = self.sampling_rate // 2  # 0.5 seconds
            fade_out = np.linspace(1.0, 0.0, fade_samples) ** 0.7  # smooth curve
            for ch in range(final_audio.shape[0]):
                final_audio[ch, -fade_samples:] *= fade_out

        # Return the final audio
        return final_audio
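

# A minimal local smoke test (not part of the endpoint contract). The
# checkpoint name below is a placeholder assumption; any MusicGen checkpoint,
# e.g. "facebook/musicgen-small", should work on a CUDA machine.
if __name__ == "__main__":
    handler = EndpointHandler(path="facebook/musicgen-small")
    result = handler({
        "inputs": {"text": "gentle acoustic guitar melody", "duration": 8},
        "parameters": {"temperature": 1.0},
    })
    if isinstance(result, dict) and "error" in result:
        print(f"Generation failed: {result['error']}")
    else:
        audio = np.array(result[0]["generated_audio"])
        print(f"Generated audio shape: {audio.shape}, "
              f"sample rate: {result[0]['sample_rate']} Hz")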