# Phoenixak99's picture
# Update handler.py
# c5073cf verified
from typing import Dict, Any
import torch
import numpy as np
import json
class EndpointHandler:
    """Inference-endpoint handler wrapping the MusicGen 'melody' model.

    Generates audio from a text prompt. Requests longer than one 30-second
    segment are fulfilled by repeatedly calling ``generate_continuation``
    and stitching the segments together with a 5-second overlap.
    """

    def __init__(self, path: str = ""):
        """Initialize model on startup.

        Args:
            path: Model path supplied by the endpoint runtime. Unused here;
                the pretrained 'melody' checkpoint is always loaded.
        """
        try:
            # Lazy import: audiocraft only needs to be importable at runtime,
            # not when this module is first loaded.
            from audiocraft.models import MusicGen
            # Load model - using melody model which supports text and melody inputs
            self.model = MusicGen.get_pretrained('melody')
            # Native output sample rate (Hz); reused for sample arithmetic below.
            self.sample_rate = self.model.sample_rate
            # Set default generation parameters
            self.model.set_generation_params(
                use_sampling=True,
                top_k=250,
                duration=30  # Default segment length
            )
        except Exception as e:
            # Keep critical error logging only
            print(f"CRITICAL: Failed to initialize model: {e}")
            raise

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Handle prediction requests.

        Expected payload:
            data["inputs"]: either a bare prompt string, or a dict with
                "prompt" (str) and optional "duration" (seconds, default 10).
            data["parameters"]: optional dict of sampling overrides; a
                "duration" given here overrides the one in inputs.

        Returns:
            On success: dict with "generated_audio" (nested list of float
            samples), "sample_rate", "prompt", "duration", "parameters".
            On failure: {"error": <message>} — errors are reported in-band,
            never raised to the caller.
        """
        try:
            # Parse request data
            inputs = data.get("inputs", {})
            # Extract prompt and duration correctly
            if isinstance(inputs, dict):
                prompt = inputs.get("prompt", "")
                request_duration = float(inputs.get("duration", 10.0))
            else:
                # Bare (non-dict) payload: treat it as the prompt itself.
                prompt = inputs
                request_duration = 10.0
            parameters = data.get("parameters", {})
            # If there's a duration in parameters, it overrides the one in inputs
            if "duration" in parameters:
                request_duration = float(parameters.get("duration", request_duration))
            # Cap duration for safety (5 minutes max - change as needed)
            duration = min(request_duration, 300.0)
            # Validate inputs
            if not prompt:
                return {"error": "No prompt provided"}
            # Generate music
            audio_data = self.generate_music(prompt, duration, parameters)
            return {
                # NOTE(review): tolist() on minutes of audio produces a very
                # large JSON payload — acceptable here, but worth confirming
                # against the endpoint's response-size limits.
                "generated_audio": audio_data.tolist(),
                "sample_rate": self.sample_rate,
                "prompt": prompt,
                "duration": duration,
                "parameters": parameters
            }
        except Exception as e:
            # Broad catch is deliberate: every failure becomes an in-band
            # error response so the endpoint never returns a 500 from here.
            print(f"ERROR: Request processing failed: {e}")
            return {"error": str(e)}

    def generate_music(self, prompt: str, duration: float, parameters: Dict) -> np.ndarray:
        """Generate music with proper continuation for long sequences.

        Args:
            prompt: Text description. A bare string is wrapped into a
                one-element list before being passed to the model.
            duration: Total requested length in seconds (already capped by
                the caller).
            parameters: Optional sampling overrides (use_sampling, top_k,
                top_p, temperature, cfg_coef).

        Returns:
            The first batch item's audio as a float numpy array
            (presumably shape (channels, samples) — confirm against
            MusicGen's output layout).

        Raises:
            Re-raises any generation failure after logging it.
        """
        try:
            # Generation parameters
            segment_duration = min(30, duration)  # Max segment length (30s)
            overlap = 5  # Overlap between segments in seconds
            # Set specific parameters if provided
            generation_params = {
                "use_sampling": parameters.get("use_sampling", True),
                "top_k": parameters.get("top_k", 250),
                "top_p": parameters.get("top_p", 0.0),
                "temperature": parameters.get("temperature", 1.0),
                "duration": segment_duration,
                "cfg_coef": parameters.get("cfg_coef", 3.0)
            }
            # Set generation parameters
            # NOTE(review): this mutates shared model state and is not
            # restored afterwards — confirm this is acceptable for
            # concurrent / subsequent requests.
            self.model.set_generation_params(**generation_params)
            # Handle prompt as list or string
            if isinstance(prompt, str):
                prompt = [prompt]
            # Generate first segment
            segment = self.model.generate(prompt, progress=False)  # Disabled progress tracking
            # If duration is less than or equal to segment_duration, we're done
            if duration <= segment_duration:
                # Trim to exact requested duration if needed
                max_samples = int(duration * self.sample_rate)
                if segment.shape[2] > max_samples:
                    segment = segment[:, :, :max_samples]
                audio_data = segment.detach().cpu().float()[0].numpy()
                return audio_data
            # Track remaining duration for multi-segment generation.
            # The overlap is added back because each continuation call
            # re-generates the `overlap` seconds used as its audio prompt.
            remaining_duration = duration - segment_duration + overlap
            # NOTE(review): segment_count is incremented below but never
            # read — appears to be leftover debugging state.
            segment_count = 1
            # Continue generating segments until we reach desired duration
            while remaining_duration > 0:
                # Adjust segment duration for last segment if needed
                if remaining_duration < segment_duration - overlap:
                    current_segment_duration = remaining_duration + overlap
                    self.model.set_generation_params(
                        use_sampling=generation_params["use_sampling"],
                        top_k=generation_params["top_k"],
                        top_p=generation_params["top_p"],
                        temperature=generation_params["temperature"],
                        duration=current_segment_duration,
                        cfg_coef=generation_params["cfg_coef"]
                    )
                # Extract last few seconds of current segment for continuation.
                # Assumes self.sample_rate is an int — a float here would make
                # the slice index invalid; TODO confirm.
                last_seconds = segment[:, :, -overlap*self.sample_rate:]
                # Generate continuation conditioned on the overlap audio and
                # the original text prompt.
                next_segment = self.model.generate_continuation(
                    last_seconds,
                    self.sample_rate,
                    prompt,
                    progress=False  # Disabled progress tracking
                )
                # Join segments (removing overlap from first segment);
                # next_segment starts with the re-generated overlap audio,
                # so dropping it from `segment` avoids duplication.
                segment = torch.cat([segment[:, :, :-overlap*self.sample_rate], next_segment], 2)
                # Update remaining duration
                if remaining_duration < segment_duration - overlap:
                    remaining_duration = 0
                else:
                    # Each full continuation contributes (segment_duration -
                    # overlap) seconds of genuinely new audio.
                    remaining_duration -= (segment_duration - overlap)
                segment_count += 1
            # Trim to exact requested duration if needed
            max_samples = int(duration * self.sample_rate)
            if segment.shape[2] > max_samples:
                segment = segment[:, :, :max_samples]
            # Convert to numpy array (first batch item only)
            audio_data = segment.detach().cpu().float()[0].numpy()
            return audio_data
        except Exception as e:
            print(f"ERROR: Music generation failed: {e}")
            raise