"""Hugging Face Inference Endpoints handler for MusicGen text-to-music generation."""

import logging
import traceback
from typing import Any, Dict

import numpy as np
import torch
from transformers import AutoProcessor, MusicgenForConditionalGeneration

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class EndpointHandler:
    def __init__(self, path=""):
        """Initialize the endpoint handler with optimized model loading."""
        logger.info("Initializing processor")
        self.processor = AutoProcessor.from_pretrained(path)

        logger.info("Loading model in standard precision")
        self.model = MusicgenForConditionalGeneration.from_pretrained(
            path,
            torch_dtype=torch.float32,
        )

        logger.info("Moving model to CUDA")
        self.model = self.model.to("cuda")

        # The audio encoder (EnCodec) defines the output sampling rate.
        self.sampling_rate = self.model.config.audio_encoder.sampling_rate
        # Single generation passes are capped at 30 s; longer requests are
        # stitched together from overlapping segments.
        self.max_segment_duration = 30

        allocated = torch.cuda.memory_allocated() / 1e9
        reserved = torch.cuda.memory_reserved() / 1e9
        logger.info(f"Initial GPU memory allocated: {allocated:.2f} GB")
        logger.info(f"GPU memory reserved: {reserved:.2f} GB")

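    # Illustrative request payloads (example values, not a fixed schema):
    #   {"inputs": "lo-fi hip hop beat with mellow piano"}
    #   {"inputs": {"text": "upbeat jazz piano trio", "duration": 45},
    #    "parameters": {"temperature": 1.0, "top_k": 250, "guidance_scale": 3.0}}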
    def __call__(self, data: Dict[str, Any]) -> Any:
        """
        Process the incoming request data and generate audio.

        Args:
            data (dict): The payload with the text prompt and generation parameters.

        Returns:
            A list containing one dict with the generated audio samples and the
            sampling rate, or a dict with an "error" key on failure.
        """
        try:
            inputs = data.get("inputs", data)
            parameters = data.get("parameters", {})

            # Accept either a bare prompt string or a structured payload.
            if isinstance(inputs, str):
                prompt = inputs
                duration = parameters.get("duration", 10)
            elif isinstance(inputs, dict):
                prompt = inputs.get("text") or inputs.get("prompt")
                duration = inputs.get("duration", parameters.get("duration", 10))
            else:
                prompt = None
                duration = parameters.get("duration", 10)

            # A "duration" in parameters takes precedence; pop it so only
            # sampling-related keys remain in the dict.
            if "duration" in parameters:
                duration = parameters.pop("duration")

            # Cap requests at five minutes of audio.
            duration = min(float(duration), 300)

            if not prompt:
                return {"error": "No prompt provided."}

            logger.info(f"Received prompt: {prompt}")
            logger.info(f"Requested duration: {duration} seconds")

            # Requests that fit in one pass (with a 5 s margin) take the
            # single-segment path; everything else is chunked.
            if duration <= self.max_segment_duration - 5:
                audio_output = self._generate_short_audio(prompt, duration, parameters)
            else:
                audio_output = self._generate_long_audio(prompt, duration, parameters)

            allocated = torch.cuda.memory_allocated() / 1e9
            logger.info(f"Post-generation GPU memory: {allocated:.2f} GB")

            return [
                {
                    "generated_audio": audio_output.tolist(),
                    "sample_rate": self.sampling_rate,
                }
            ]

        except Exception as e:
            logger.error(f"Exception during generation: {e}")
            logger.error(traceback.format_exc())
            return {"error": str(e)}

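    # Token budget (illustrative): MusicGen emits roughly 50 audio tokens per
    # second of output, so a 10 s request maps to max_new_tokens = 500.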
    def _generate_short_audio(self, prompt, duration, params):
        """Generate a single audio segment using the transformers API."""
        logger.info(f"Generating short audio segment: {duration}s")

        try:
            inputs = self.processor(
                text=[prompt],
                padding=True,
                return_tensors="pt",
            ).to("cuda")

            # Roughly 50 audio tokens per second of output.
            max_new_tokens = int(duration * 50)

            generation_kwargs = {
                "max_new_tokens": max_new_tokens,
                "do_sample": True,
                "guidance_scale": 3.0,
            }

            # Clamp caller-supplied sampling parameters to safe ranges.
            if "top_k" in params:
                generation_kwargs["top_k"] = min(int(params["top_k"]), 500)

            if "temperature" in params:
                temp = float(params["temperature"])
                if temp > 0.1:
                    generation_kwargs["temperature"] = min(temp, 1.5)

            if "guidance_scale" in params:
                generation_kwargs["guidance_scale"] = min(float(params["guidance_scale"]), 3.0)
            elif "cfg_coef" in params:
                generation_kwargs["guidance_scale"] = min(float(params["cfg_coef"]), 3.0)

            logger.info(f"Generation parameters: {generation_kwargs}")
            outputs = self.model.generate(**inputs, **generation_kwargs)

            # outputs has shape (batch, channels, samples); return the first item.
            return outputs[0].cpu().numpy()

        except Exception as e:
            logger.error(f"Error during generation: {e}")

            # Retry once with classifier-free guidance disabled (scale 1.0).
            try:
                logger.info("Trying generation with minimal parameters")

                inputs = self.processor(
                    text=[prompt],
                    padding=True,
                    return_tensors="pt",
                ).to("cuda")

                outputs = self.model.generate(
                    **inputs,
                    # Recompute the budget here: the first attempt may have
                    # failed before max_new_tokens was assigned.
                    max_new_tokens=int(duration * 50),
                    do_sample=True,
                    guidance_scale=1.0,
                )
                return outputs[0].cpu().numpy()
            except Exception as e2:
                logger.error(f"Second attempt failed: {e2}")
                raise

    def _simple_crossfade(self, segment1, segment2, overlap_samples):
        """Apply a simple linear crossfade between segments."""
        # Segments are shaped (channels, samples).
        length1 = segment1.shape[1]
        length2 = segment2.shape[1]

        # The overlap cannot exceed either segment.
        overlap_samples = min(overlap_samples, length1, length2)

        result_length = length1 + length2 - overlap_samples
        result = np.zeros((segment1.shape[0], result_length), dtype=segment1.dtype)

        # Copy the non-overlapping part of each segment.
        result[:, :length1 - overlap_samples] = segment1[:, :length1 - overlap_samples]
        result[:, length1:] = segment2[:, overlap_samples:]

        if overlap_samples > 0:
            # Linear ramps: segment1 fades out while segment2 fades in.
            fade_out = np.linspace(1, 0, overlap_samples)
            fade_in = np.linspace(0, 1, overlap_samples)

            segment1_end = segment1[:, -overlap_samples:].copy()
            segment2_start = segment2[:, :overlap_samples].copy()

            for ch in range(segment1_end.shape[0]):
                segment1_end[ch] *= fade_out
                segment2_start[ch] *= fade_in

            crossfaded = segment1_end + segment2_start
            result[:, length1 - overlap_samples:length1] = crossfaded

        return result

    def _advanced_crossfade(self, segment1, segment2, overlap_samples):
        """Apply a smoother raised-cosine crossfade between segments."""
        length1 = segment1.shape[1]
        length2 = segment2.shape[1]

        # The overlap cannot exceed either segment.
        overlap_samples = min(overlap_samples, length1, length2)

        result_length = length1 + length2 - overlap_samples
        result = np.zeros((segment1.shape[0], result_length), dtype=segment1.dtype)

        # Copy the non-overlapping part of each segment.
        result[:, :length1 - overlap_samples] = segment1[:, :length1 - overlap_samples]
        result[:, length1:] = segment2[:, overlap_samples:]

        if overlap_samples > 0:
            # Raised-cosine ramps: cos^2 and sin^2 sum to one at every point
            # of the overlap, giving a smoother transition than linear ramps.
            t = np.linspace(0, np.pi / 2, overlap_samples)
            fade_out = np.cos(t) ** 2
            fade_in = np.sin(t) ** 2

            segment1_end = segment1[:, -overlap_samples:].copy()
            segment2_start = segment2[:, :overlap_samples].copy()

            for ch in range(segment1_end.shape[0]):
                segment1_end[ch] *= fade_out
                segment2_start[ch] *= fade_in

            crossfaded = segment1_end + segment2_start
            result[:, length1 - overlap_samples:length1] = crossfaded

        return result

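    # Stitching arithmetic (illustrative): every segment after the first is
    # generated with overlap_duration extra seconds that the crossfade consumes,
    # so two 30 s segments joined with a 5 s overlap yield 55 s of audio. At
    # MusicGen's usual 32 kHz EnCodec rate, a 5 s overlap spans 160,000 samples.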
    def _generate_long_audio(self, prompt, total_duration, params):
        """Generate long audio with improved segment continuity."""
        # Consecutive segments share a 5 s overlap that the crossfade consumes.
        overlap_duration = 5

        remaining_duration = total_duration
        final_audio = None
        segment_idx = 0

        segment_duration = self.max_segment_duration
        overlap_samples = int(overlap_duration * self.sampling_rate)

        while remaining_duration > 0:
            # Segments after the first are generated with extra overlap audio.
            if segment_idx == 0:
                target_duration = min(segment_duration, remaining_duration)
            else:
                target_duration = min(segment_duration, remaining_duration + overlap_duration)

            logger.info(f"Generating segment {segment_idx + 1}, duration: {target_duration:.1f}s")

            try:
                if segment_idx == 0:
                    segment_prompt = prompt
                else:
                    # Hint to the model that this segment continues the piece.
                    segment_prompt = f"{prompt} [continuing segment {segment_idx + 1}, seamless continuation]"

                inputs = self.processor(
                    text=[segment_prompt],
                    padding=True,
                    return_tensors="pt",
                ).to("cuda")

                # Roughly 50 audio tokens per second of output.
                max_new_tokens = int(target_duration * 50)

                generation_kwargs = {
                    "max_new_tokens": max_new_tokens,
                    "do_sample": True,
                    "guidance_scale": 3.0,
                }

                # Clamp caller-supplied sampling parameters to safe ranges.
                if "top_k" in params:
                    generation_kwargs["top_k"] = min(int(params["top_k"]), 500)

                if "temperature" in params:
                    temp = float(params["temperature"])
                    if temp > 0.1:
                        generation_kwargs["temperature"] = min(temp, 1.5)

                if "guidance_scale" in params:
                    generation_kwargs["guidance_scale"] = min(float(params["guidance_scale"]), 3.0)
                elif "cfg_coef" in params:
                    generation_kwargs["guidance_scale"] = min(float(params["cfg_coef"]), 3.0)

                outputs = self.model.generate(**inputs, **generation_kwargs)
                segment_output = outputs[0].cpu().numpy()

                if segment_idx == 0:
                    final_audio = segment_output
                else:
                    final_audio = self._advanced_crossfade(final_audio, segment_output, overlap_samples)

                # The overlap region does not count toward finished duration.
                if segment_idx == 0:
                    remaining_duration -= target_duration
                else:
                    remaining_duration -= (target_duration - overlap_duration)

                # Release cached allocations between segments.
                torch.cuda.empty_cache()

                logger.info(f"GPU memory usage: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
                logger.info(f"Remaining duration: {remaining_duration:.1f}s")

            except Exception as e:
                logger.error(f"Error generating segment {segment_idx + 1}: {e}")
                if final_audio is not None:
                    logger.info("Returning partial audio after error")
                    return final_audio

                # Nothing generated yet: retry once with minimal parameters
                # and return whatever short clip comes back.
                try:
                    logger.info("Trying minimal generation parameters")
                    inputs = self.processor(
                        text=[prompt],
                        padding=True,
                        return_tensors="pt",
                    ).to("cuda")

                    outputs = self.model.generate(
                        **inputs,
                        max_new_tokens=int(min(target_duration, 15.0) * 50),
                        do_sample=True,
                    )
                    return outputs[0].cpu().numpy()
                except Exception as e2:
                    logger.error(f"Minimal generation also failed: {e2}")
                    raise

            segment_idx += 1

        # Apply a gentle half-second fade-out so the clip does not end abruptly.
        if final_audio.shape[1] > self.sampling_rate // 2:
            fade_samples = self.sampling_rate // 2
            fade_out = np.linspace(1.0, 0.0, fade_samples) ** 0.7
            for ch in range(final_audio.shape[0]):
                final_audio[ch, -fade_samples:] *= fade_out

        return final_audio
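
# ---------------------------------------------------------------------------
# Minimal local smoke test -- a sketch, not part of the serving contract.
# Assumes a CUDA device and a MusicGen checkpoint; "facebook/musicgen-small"
# is an example id, so adjust the path to match your deployment.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    handler = EndpointHandler(path="facebook/musicgen-small")
    response = handler(
        {
            "inputs": {"text": "gentle acoustic guitar melody", "duration": 10},
            "parameters": {"temperature": 1.0, "top_k": 250},
        }
    )
    if isinstance(response, dict) and "error" in response:
        print(f"Generation failed: {response['error']}")
    else:
        audio = np.array(response[0]["generated_audio"])
        print(f"Generated audio shape: {audio.shape}, "
              f"sample rate: {response[0]['sample_rate']} Hz")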