import logging
from typing import Any, Dict

import numpy as np
import torch
from transformers import AutoProcessor, MusicgenForConditionalGeneration

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class EndpointHandler:
    def __init__(self, path=""):
        """Initialize the endpoint handler with optimized model loading."""
        # Load the processor
        logger.info("Initializing processor")
        self.processor = AutoProcessor.from_pretrained(path)

        # Load the model in standard precision for better numerical stability
        logger.info("Loading model in standard precision")
        self.model = MusicgenForConditionalGeneration.from_pretrained(
            path,
            torch_dtype=torch.float32,
        )

        # Move the model to CUDA after loading
        logger.info("Moving model to CUDA")
        self.model = self.model.to("cuda")

        # Get the model's audio configuration
        self.sampling_rate = self.model.config.audio_encoder.sampling_rate
        self.max_segment_duration = 30  # Maximum duration per segment in seconds

        # Log GPU memory info
        allocated = torch.cuda.memory_allocated() / 1e9
        reserved = torch.cuda.memory_reserved() / 1e9
        logger.info(f"Initial GPU memory allocated: {allocated:.2f} GB")
        logger.info(f"GPU memory reserved: {reserved:.2f} GB")

    def __call__(self, data: Dict[str, Any]) -> Any:
        """
        Process the incoming request data and generate audio.

        Args:
            data (dict): The payload with the text prompt and generation parameters.
        """
        try:
            # Extract inputs and parameters from the payload
            inputs = data.get("inputs", data)
            parameters = data.get("parameters", {})

            # Handle inputs
            if isinstance(inputs, str):
                prompt = inputs
                duration = parameters.get("duration", 10)
            elif isinstance(inputs, dict):
                prompt = inputs.get("text") or inputs.get("prompt")
                duration = inputs.get("duration", parameters.get("duration", 10))
            else:
                prompt = None
                duration = parameters.get("duration", 10)

            # Override duration if provided in parameters, and strip it so it
            # is not forwarded to the generation call
            if "duration" in parameters:
                duration = parameters.pop("duration")

            # Cap duration to prevent excessive resource usage
            duration = min(float(duration), 300)

            # Validate the prompt
            if not prompt:
                return {"error": "No prompt provided."}

            logger.info(f"Received prompt: {prompt}")
            logger.info(f"Requested duration: {duration} seconds")

            # Generate audio
            if duration <= self.max_segment_duration - 5:
                # For short durations, generate in one go
                audio_output = self._generate_short_audio(prompt, duration, parameters)
            else:
                # Use basic segmentation for longer durations
                audio_output = self._generate_long_audio(prompt, duration, parameters)

            # Monitor GPU memory after generation
            allocated = torch.cuda.memory_allocated() / 1e9
            logger.info(f"Post-generation GPU memory: {allocated:.2f} GB")

            return [
                {
                    "generated_audio": audio_output.tolist(),
                    "sample_rate": self.sampling_rate,
                }
            ]

        except Exception as e:
            logger.error(f"Exception during generation: {e}")
            import traceback

            logger.error(traceback.format_exc())
            return {"error": str(e)}
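    # Illustrative request payload (field names mirror the parsing in
    # __call__ above; the concrete values are hypothetical):
    #
    #   {
    #       "inputs": {"text": "upbeat jazz with saxophone", "duration": 45},
    #       "parameters": {"temperature": 1.0, "top_k": 250, "guidance_scale": 3.0}
    #   }
    #
    # A bare string under "inputs" is also accepted, in which case the
    # duration falls back to parameters["duration"] or 10 seconds.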
    def _generate_short_audio(self, prompt, duration, params):
        """Generate a single audio segment using the transformers API."""
        logger.info(f"Generating short audio segment: {duration}s")

        # Calculate max_new_tokens from duration up front so it is also
        # available to the fallback path below.
        # Each second is approximately 50 tokens at 32 kHz.
        max_new_tokens = int(duration * 50)

        try:
            # Process the text input
            inputs = self.processor(
                text=[prompt],
                padding=True,
                return_tensors="pt",
            ).to("cuda")

            # Generation parameters for the transformers implementation
            generation_kwargs = {
                "max_new_tokens": max_new_tokens,
                "do_sample": True,
                "guidance_scale": 3.0,  # Default classifier-free guidance scale
            }

            # Add additional parameters if provided
            if "top_k" in params:
                generation_kwargs["top_k"] = min(int(params["top_k"]), 500)
            if "temperature" in params:
                temp = float(params["temperature"])
                if temp > 0.1:  # Avoid near-zero temperature
                    generation_kwargs["temperature"] = min(temp, 1.5)
            if "guidance_scale" in params:
                generation_kwargs["guidance_scale"] = min(float(params["guidance_scale"]), 3.0)
            elif "cfg_coef" in params:
                generation_kwargs["guidance_scale"] = min(float(params["cfg_coef"]), 3.0)

            # Generate audio
            logger.info(f"Generation parameters: {generation_kwargs}")
            outputs = self.model.generate(**inputs, **generation_kwargs)

            # Return the audio as a (channels, samples) array
            return outputs[0].cpu().numpy()

        except Exception as e:
            logger.error(f"Error during generation: {e}")
            # Retry with minimal parameters
            try:
                logger.info("Trying generation with minimal parameters")
                inputs = self.processor(
                    text=[prompt],
                    padding=True,
                    return_tensors="pt",
                ).to("cuda")

                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    do_sample=True,
                    guidance_scale=1.0,  # Minimal guidance
                )
                return outputs[0].cpu().numpy()
            except Exception as e2:
                logger.error(f"Second attempt failed: {e2}")
                raise e2

    def _simple_crossfade(self, segment1, segment2, overlap_samples):
        """Apply a simple linear crossfade between segments.

        Kept as a lighter-weight alternative to _advanced_crossfade; not
        currently called by the generation path.
        """
        # Get the length of the segments (arrays are (channels, samples))
        length1 = segment1.shape[1]
        length2 = segment2.shape[1]

        # Ensure we have enough samples for crossfading
        overlap_samples = min(overlap_samples, length1, length2)

        # Create the result array (total length minus overlap)
        result_length = length1 + length2 - overlap_samples
        result = np.zeros((segment1.shape[0], result_length), dtype=segment1.dtype)

        # Copy the non-overlapping part of segment1
        result[:, :length1 - overlap_samples] = segment1[:, :length1 - overlap_samples]

        # Copy the non-overlapping part of segment2
        result[:, length1:] = segment2[:, overlap_samples:]

        # Apply a linear crossfade to the overlapping parts
        if overlap_samples > 0:
            # Linear fade factors
            fade_out = np.linspace(1, 0, overlap_samples)
            fade_in = np.linspace(0, 1, overlap_samples)

            # Get the overlapping parts
            segment1_end = segment1[:, -overlap_samples:].copy()
            segment2_start = segment2[:, :overlap_samples].copy()

            # Apply the fades per channel
            for ch in range(segment1_end.shape[0]):
                segment1_end[ch] *= fade_out
                segment2_start[ch] *= fade_in

            # Combine the faded parts and write them into the overlap region
            crossfaded = segment1_end + segment2_start
            result[:, length1 - overlap_samples:length1] = crossfaded

        return result
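    # Worked example of the linear fade factors (a sketch; the 4-sample
    # overlap is hypothetical, real overlaps span seconds of audio):
    #
    #   np.linspace(1, 0, 4)  ->  [1.0, 0.667, 0.333, 0.0]   (fade_out)
    #   np.linspace(0, 1, 4)  ->  [0.0, 0.333, 0.667, 1.0]   (fade_in)
    #
    # The two gains sum to exactly 1 at every sample, so steady-state
    # signal level is preserved across the overlap region.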
    def _advanced_crossfade(self, segment1, segment2, overlap_samples):
        """Apply a smooth raised-cosine crossfade between segments."""
        # Get the length of the segments (arrays are (channels, samples))
        length1 = segment1.shape[1]
        length2 = segment2.shape[1]

        # Ensure we have enough samples for crossfading
        overlap_samples = min(overlap_samples, length1, length2)

        # Create the result array (total length minus overlap)
        result_length = length1 + length2 - overlap_samples
        result = np.zeros((segment1.shape[0], result_length), dtype=segment1.dtype)

        # Copy the non-overlapping part of segment1
        result[:, :length1 - overlap_samples] = segment1[:, :length1 - overlap_samples]

        # Copy the non-overlapping part of segment2
        result[:, length1:] = segment2[:, overlap_samples:]

        # Apply the raised-cosine crossfade to the overlapping parts
        if overlap_samples > 0:
            # cos^2/sin^2 curves: smooth S-shaped fades whose gains sum to
            # exactly 1 at every sample (constant amplitude rather than
            # strictly equal power)
            t = np.linspace(0, np.pi / 2, overlap_samples)
            fade_out = np.cos(t) ** 2
            fade_in = np.sin(t) ** 2

            # Get the overlapping parts
            segment1_end = segment1[:, -overlap_samples:].copy()
            segment2_start = segment2[:, :overlap_samples].copy()

            # Apply the fades per channel
            for ch in range(segment1_end.shape[0]):
                segment1_end[ch] *= fade_out
                segment2_start[ch] *= fade_in

            # Combine the faded parts and write them into the overlap region
            crossfaded = segment1_end + segment2_start
            result[:, length1 - overlap_samples:length1] = crossfaded

        return result

    def _generate_long_audio(self, prompt, total_duration, params):
        """Generate long audio with improved segment continuity."""
        # Overlap duration for the crossfade; a longer overlap gives
        # smoother transitions between segments
        overlap_duration = 5

        # Initialize variables
        remaining_duration = total_duration
        final_audio = None
        segment_idx = 0

        segment_duration = self.max_segment_duration
        overlap_samples = int(overlap_duration * self.sampling_rate)

        # Process in segments
        while remaining_duration > 0:
            # Later segments are generated slightly longer than the new audio
            # they contribute, because overlap_duration seconds are consumed
            # by the crossfade with the previous segment
            extra = overlap_duration if segment_idx > 0 else 0
            target_duration = min(segment_duration, remaining_duration + extra)
            logger.info(f"Generating segment {segment_idx + 1}, duration: {target_duration:.1f}s")

            try:
                # Use continuation prompts directly rather than conditioning
                # on previously generated audio
                if segment_idx == 0:
                    # First segment with the basic prompt
                    segment_prompt = prompt
                else:
                    # Subsequent segments with an enhanced continuation prompt
                    segment_prompt = f"{prompt} [continuing segment {segment_idx + 1}, seamless continuation]"

                # Process the text for this segment
                inputs = self.processor(
                    text=[segment_prompt],
                    padding=True,
                    return_tensors="pt",
                ).to("cuda")

                # Calculate max_new_tokens from duration (~50 tokens/second)
                max_new_tokens = int(target_duration * 50)

                # Generation parameters for the transformers implementation
                generation_kwargs = {
                    "max_new_tokens": max_new_tokens,
                    "do_sample": True,
                    "guidance_scale": 3.0,
                }

                # Add additional parameters if provided
                if "top_k" in params:
                    generation_kwargs["top_k"] = min(int(params["top_k"]), 500)
                if "temperature" in params:
                    temp = float(params["temperature"])
                    if temp > 0.1:
                        generation_kwargs["temperature"] = min(temp, 1.5)
                if "guidance_scale" in params:
                    generation_kwargs["guidance_scale"] = min(float(params["guidance_scale"]), 3.0)
                elif "cfg_coef" in params:
                    generation_kwargs["guidance_scale"] = min(float(params["cfg_coef"]), 3.0)

                # Generate this segment
                outputs = self.model.generate(**inputs, **generation_kwargs)
                segment_output = outputs[0].cpu().numpy()

                # Add this segment to the final output
                if segment_idx == 0:
                    final_audio = segment_output
                else:
                    # Crossfade for smoother transitions
                    final_audio = self._advanced_crossfade(final_audio, segment_output, overlap_samples)

                # Update the remaining duration
                if segment_idx == 0:
                    remaining_duration -= target_duration
                else:
                    remaining_duration -= (target_duration - overlap_duration)

                # Clear the CUDA cache between segments
                torch.cuda.empty_cache()

                # Log progress
                logger.info(f"GPU memory usage: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
                logger.info(f"Remaining duration: {remaining_duration:.1f}s")

            except Exception as e:
                logger.error(f"Error generating segment {segment_idx + 1}: {e}")
                if final_audio is not None:
                    logger.info("Returning partial audio after error")
                    return final_audio

                # Try again with minimal parameters
                try:
                    logger.info("Trying minimal generation parameters")
                    inputs = self.processor(
                        text=[prompt],
                        padding=True,
                        return_tensors="pt",
                    ).to("cuda")

                    outputs = self.model.generate(
                        **inputs,
                        max_new_tokens=int(min(target_duration, 15.0) * 50),
                        do_sample=True,
                    )
                    return outputs[0].cpu().numpy()
                except Exception as e2:
                    logger.error(f"Minimal generation also failed: {e2}")
                    raise e2
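            # Segment/overlap bookkeeping, illustrated with a hypothetical
            # 60-second request: segment 1 contributes its full 30 s; each
            # later segment is generated at up to 30 s but only adds
            # (target_duration - 5 s of overlap) of new audio, so 60 s of
            # output takes three segments (30 + 25 + 5) rather than two.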
            # Move to the next segment
            segment_idx += 1

            # Break if we've generated enough audio
            if remaining_duration <= 0:
                break

        # Apply a smooth fade-out to the last 0.5 seconds
        if final_audio is not None and final_audio.shape[1] > self.sampling_rate // 2:
            fade_samples = self.sampling_rate // 2  # 0.5 seconds
            fade_out = np.linspace(1.0, 0.0, fade_samples) ** 0.7  # Smooth curve
            for ch in range(final_audio.shape[0]):
                final_audio[ch, -fade_samples:] *= fade_out

        # Return the final audio
        return final_audio
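
# Minimal local smoke test, a sketch for manual runs rather than part of the
# serving contract. The checkpoint name and payload values below are
# illustrative assumptions; running this requires a CUDA device.
if __name__ == "__main__":
    handler = EndpointHandler(path="facebook/musicgen-small")
    result = handler(
        {
            "inputs": {"text": "lo-fi hip hop beat with mellow piano", "duration": 8},
            "parameters": {"temperature": 1.0, "top_k": 250},
        }
    )
    if isinstance(result, dict) and "error" in result:
        print("Generation failed:", result["error"])
    else:
        audio = np.array(result[0]["generated_audio"])
        print(f"Generated audio with shape {audio.shape} at {result[0]['sample_rate']} Hz")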