from typing import Dict, Any

import numpy as np
import torch
|


class EndpointHandler:
    def __init__(self, path=""):
        """Load the model once at container startup."""
        try:
            from audiocraft.models import MusicGen

            # 'melody' is the legacy short name; recent audiocraft releases
            # map it to 'facebook/musicgen-melody'.
            self.model = MusicGen.get_pretrained('melody')
            self.sample_rate = self.model.sample_rate

            # Defaults only; generate_music() overrides these per request.
            self.model.set_generation_params(
                use_sampling=True,
                top_k=250,
                duration=30
            )
        except Exception as e:
            print(f"CRITICAL: Failed to initialize model: {e}")
            raise
|

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Handle a prediction request."""
        try:
            inputs = data.get("inputs", {})

            # "inputs" may be a dict of options or a bare prompt string.
            if isinstance(inputs, dict):
                prompt = inputs.get("prompt", "")
                request_duration = float(inputs.get("duration", 10.0))
            else:
                prompt = inputs
                request_duration = 10.0

            parameters = data.get("parameters", {})

            # A duration in "parameters" overrides one given in "inputs".
            if "duration" in parameters:
                request_duration = float(parameters.get("duration", request_duration))

            # Cap requests at five minutes.
            duration = min(request_duration, 300.0)

            if not prompt:
                return {"error": "No prompt provided"}

            audio_data = self.generate_music(prompt, duration, parameters)

            return {
                "generated_audio": audio_data.tolist(),
                "sample_rate": self.sample_rate,
                "prompt": prompt,
                "duration": duration,
                "parameters": parameters
            }

        except Exception as e:
            print(f"ERROR: Request processing failed: {e}")
            return {"error": str(e)}
|

    def generate_music(self, prompt: str, duration: float, parameters: Dict[str, Any]) -> np.ndarray:
        """Generate music, stitching overlapping continuations for long requests."""
        try:
            # MusicGen generates at most 30 s per call; longer requests are
            # assembled from 30 s windows joined with a 5 s overlap.
            segment_duration = min(30, duration)
            overlap = 5

            generation_params = {
                "use_sampling": parameters.get("use_sampling", True),
                "top_k": parameters.get("top_k", 250),
                "top_p": parameters.get("top_p", 0.0),
                "temperature": parameters.get("temperature", 1.0),
                "duration": segment_duration,
                "cfg_coef": parameters.get("cfg_coef", 3.0)
            }

            self.model.set_generation_params(**generation_params)

            # MusicGen expects a list of text descriptions.
            if isinstance(prompt, str):
                prompt = [prompt]

            segment = self.model.generate(prompt, progress=False)

            # Short request: a single window suffices, so trim and return.
            if duration <= segment_duration:
                max_samples = int(duration * self.sample_rate)
                if segment.shape[2] > max_samples:
                    segment = segment[:, :, :max_samples]
                audio_data = segment.detach().cpu().float()[0].numpy()
                return audio_data
|

            # The first window is done; generate the remainder in
            # overlapping continuation windows.
            remaining_duration = duration - segment_duration + overlap

            while remaining_duration > 0:
                # Final window: shrink it so the total length lands close
                # to the requested duration.
                if remaining_duration < segment_duration - overlap:
                    current_segment_duration = remaining_duration + overlap
                    self.model.set_generation_params(
                        use_sampling=generation_params["use_sampling"],
                        top_k=generation_params["top_k"],
                        top_p=generation_params["top_p"],
                        temperature=generation_params["temperature"],
                        duration=current_segment_duration,
                        cfg_coef=generation_params["cfg_coef"]
                    )

                # Condition the next window on the last `overlap` seconds
                # of audio generated so far.
                last_seconds = segment[:, :, -overlap * self.sample_rate:]

                next_segment = self.model.generate_continuation(
                    last_seconds,
                    self.sample_rate,
                    prompt,
                    progress=False
                )

                # The continuation re-generates the overlap region, so drop
                # it from the existing audio before concatenating.
                segment = torch.cat([segment[:, :, :-overlap * self.sample_rate], next_segment], dim=2)

                if remaining_duration < segment_duration - overlap:
                    remaining_duration = 0
                else:
                    remaining_duration -= (segment_duration - overlap)
|

            # Trim to the exact requested length.
            max_samples = int(duration * self.sample_rate)
            if segment.shape[2] > max_samples:
                segment = segment[:, :, :max_samples]

            audio_data = segment.detach().cpu().float()[0].numpy()

            return audio_data

        except Exception as e:
            print(f"ERROR: Music generation failed: {e}")
            raise
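

# Minimal local smoke test: a sketch that assumes audiocraft is installed
# and can fetch the model weights; not part of the deployed endpoint contract.
if __name__ == "__main__":
    handler = EndpointHandler()
    result = handler({"inputs": {"prompt": "upbeat acoustic folk", "duration": 10}})
    if "error" in result:
        print("Generation failed:", result["error"])
    else:
        print(f"Generated {result['duration']}s of audio at {result['sample_rate']} Hz")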