import os
import io
import gc
import math
import time
import uuid
import json
import spaces
import random
from abc import ABC, abstractmethod
from dataclasses import dataclass, field, asdict
from typing import Dict, List, Tuple, Optional, Any, Union
from enum import Enum

import gradio as gr
import numpy as np
import torch
from transformers import AutoModel, AutoTokenizer
import mido
from mido import Message, MidiFile, MidiTrack

# Configuration Classes

class ComputeMode(Enum):
    """Enum for computation modes."""
    FULL_MODEL = "Full model"
    MOCK_LATENTS = "Mock latents"


class MusicRole(Enum):
    """Enum for musical roles/layers."""
    MELODY = "melody"
    BASS = "bass"
    HARMONY = "harmony"
    PAD = "pad"
    ACCENT = "accent"
    ATMOSPHERE = "atmosphere"

@dataclass
class ScaleDefinition:
    """Represents a musical scale."""
    name: str
    notes: List[int]
    description: str = ""

    def __post_init__(self):
        """Validate scale notes are within MIDI range."""
        for note in self.notes:
            if not 0 <= note <= 127:
                raise ValueError(f"MIDI note {note} out of range (0-127)")

@dataclass
class InstrumentMapping:
    """Maps a layer to an instrument and musical role."""
    program: int  # MIDI program number
    role: MusicRole
    channel: int
    name: str = ""

    def __post_init__(self):
        """Validate MIDI program and channel."""
        if not 0 <= self.program <= 127:
            raise ValueError(f"MIDI program {self.program} out of range")
        if not 0 <= self.channel <= 15:
            raise ValueError(f"MIDI channel {self.channel} out of range")

@dataclass
class GenerationConfig:
    """Complete configuration for music generation."""
    model_name: str
    compute_mode: ComputeMode
    base_tempo: int
    velocity_range: Tuple[int, int]
    scale: ScaleDefinition
    num_layers_limit: int
    seed: int
    instrument_preset: str
    # Additional configuration options
    quantization_grid: int = 120
    octave_range: int = 2
    dynamics_curve: str = "linear"  # linear, exponential, logarithmic

    def validate(self):
        """Validate configuration parameters."""
        if not 1 <= self.base_tempo <= 2000:
            raise ValueError("Tempo must be between 1 and 2000")
        if not 1 <= self.velocity_range[0] < self.velocity_range[1] <= 127:
            raise ValueError("Invalid velocity range")
        if not 1 <= self.num_layers_limit <= 32:
            raise ValueError("Number of layers must be between 1 and 32")

    def to_dict(self) -> Dict:
        """Convert config to dictionary for serialization."""
        return {
            "model_name": self.model_name,
            "compute_mode": self.compute_mode.value,
            "base_tempo": self.base_tempo,
            "velocity_range": self.velocity_range,
            "scale_name": self.scale.name,
            "scale_notes": self.scale.notes,
            "num_layers_limit": self.num_layers_limit,
            "seed": self.seed,
            "instrument_preset": self.instrument_preset,
            "quantization_grid": self.quantization_grid,
            "octave_range": self.octave_range,
            "dynamics_curve": self.dynamics_curve,
        }

    @classmethod
    def from_dict(cls, data: Dict, scale_manager: "ScaleManager") -> "GenerationConfig":
        """Create config from dictionary."""
        scale = scale_manager.get_scale(data["scale_name"])
        if scale is None:
            scale = ScaleDefinition(name="Custom", notes=data["scale_notes"])
        return cls(
            model_name=data["model_name"],
            compute_mode=ComputeMode(data["compute_mode"]),
            base_tempo=data["base_tempo"],
            velocity_range=tuple(data["velocity_range"]),
            scale=scale,
            num_layers_limit=data["num_layers_limit"],
            seed=data["seed"],
            instrument_preset=data["instrument_preset"],
            quantization_grid=data.get("quantization_grid", 120),
            octave_range=data.get("octave_range", 2),
            dynamics_curve=data.get("dynamics_curve", "linear"),
        )
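
# Serialization round-trip sketch (illustrative only; `cfg` is an existing
# GenerationConfig and `scales` a ScaleManager instance, both hypothetical names):
#     data = cfg.to_dict()                              # plain, JSON-serializable dict
#     restored = GenerationConfig.from_dict(data, scales)
#     assert restored.scale.name == cfg.scale.name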

@dataclass
class Latents:
    """Container for model latents."""
    hidden_states: List[torch.Tensor]
    attentions: List[torch.Tensor]
    num_layers: int
    num_tokens: int
    metadata: Dict[str, Any] = field(default_factory=dict)

# Music Components

class ScaleManager:
    """Manages musical scales and modes."""

    def __init__(self):
        """Initialize with default scales."""
        self.scales = {
            "C pentatonic": ScaleDefinition(
                "C pentatonic",
                [60, 62, 65, 67, 70, 72, 74, 77],
                "Major pentatonic scale"
            ),
            "C major": ScaleDefinition(
                "C major",
                [60, 62, 64, 65, 67, 69, 71, 72],
                "Major scale (Ionian mode)"
            ),
            "A minor": ScaleDefinition(
                "A minor",
                [57, 59, 60, 62, 64, 65, 67, 69],
                "Natural minor scale (Aeolian mode)"
            ),
            "D dorian": ScaleDefinition(
                "D dorian",
                [62, 64, 65, 67, 69, 71, 72, 74],
                "Dorian mode - minor with raised 6th"
            ),
            "E phrygian": ScaleDefinition(
                "E phrygian",
                [64, 65, 67, 69, 71, 72, 74, 76],
                "Phrygian mode - minor with lowered 2nd"
            ),
            "G mixolydian": ScaleDefinition(
                "G mixolydian",
                [67, 69, 71, 72, 74, 76, 77, 79],
                "Mixolydian mode - major with lowered 7th"
            ),
            "Blues scale": ScaleDefinition(
                "Blues scale",
                [60, 63, 65, 66, 67, 70, 72, 75],
                "Blues scale with blue notes"
            ),
            "Chromatic": ScaleDefinition(
                "Chromatic",
                list(range(60, 72)),
                "All 12 semitones"
            ),
        }

    def get_scale(self, name: str) -> Optional[ScaleDefinition]:
        """Get scale by name."""
        return self.scales.get(name)

    def add_custom_scale(self, name: str, notes: List[int], description: str = "") -> ScaleDefinition:
        """Add a custom scale."""
        scale = ScaleDefinition(name, notes, description)
        self.scales[name] = scale
        return scale

    def list_scales(self) -> List[str]:
        """Get list of available scale names."""
        return list(self.scales.keys())

class InstrumentPresetManager:
    """Manages instrument presets for different musical styles."""

    def __init__(self):
        """Initialize with default presets."""
        self.presets = {
            "Ensemble (melody+bass+pad etc.)": [
                InstrumentMapping(0, MusicRole.MELODY, 0, "Piano"),
                InstrumentMapping(33, MusicRole.BASS, 1, "Electric Bass"),
                InstrumentMapping(46, MusicRole.HARMONY, 2, "Harp"),
                InstrumentMapping(48, MusicRole.PAD, 3, "String Ensemble"),
                InstrumentMapping(11, MusicRole.ACCENT, 4, "Vibraphone"),
                InstrumentMapping(89, MusicRole.ATMOSPHERE, 5, "Pad Warm"),
            ],
            "Piano Trio (melody+bass+harmony)": [
                InstrumentMapping(0, MusicRole.MELODY, 0, "Piano"),
                InstrumentMapping(33, MusicRole.BASS, 1, "Electric Bass"),
                InstrumentMapping(0, MusicRole.HARMONY, 2, "Piano"),
                InstrumentMapping(48, MusicRole.PAD, 3, "String Ensemble"),
                InstrumentMapping(0, MusicRole.ACCENT, 4, "Piano"),
                InstrumentMapping(0, MusicRole.ATMOSPHERE, 5, "Piano"),
            ],
            "Pads & Atmosphere": [
                InstrumentMapping(48, MusicRole.PAD, 0, "String Ensemble"),
                InstrumentMapping(48, MusicRole.PAD, 1, "String Ensemble"),
                InstrumentMapping(89, MusicRole.ATMOSPHERE, 2, "Pad Warm"),
                InstrumentMapping(89, MusicRole.ATMOSPHERE, 3, "Pad Warm"),
                InstrumentMapping(46, MusicRole.HARMONY, 4, "Harp"),
                InstrumentMapping(11, MusicRole.ACCENT, 5, "Vibraphone"),
            ],
            "Orchestral": [
                InstrumentMapping(40, MusicRole.MELODY, 0, "Violin"),
                InstrumentMapping(42, MusicRole.BASS, 1, "Cello"),
                InstrumentMapping(46, MusicRole.HARMONY, 2, "Harp"),
                InstrumentMapping(48, MusicRole.PAD, 3, "String Ensemble"),
                InstrumentMapping(73, MusicRole.ACCENT, 4, "Flute"),
                InstrumentMapping(49, MusicRole.ATMOSPHERE, 5, "Slow Strings"),
            ],
            "Electronic": [
                InstrumentMapping(80, MusicRole.MELODY, 0, "Lead Square"),
                InstrumentMapping(38, MusicRole.BASS, 1, "Synth Bass"),
                InstrumentMapping(81, MusicRole.HARMONY, 2, "Lead Sawtooth"),
                InstrumentMapping(90, MusicRole.PAD, 3, "Pad Polysynth"),
                InstrumentMapping(82, MusicRole.ACCENT, 4, "Lead Calliope"),
                InstrumentMapping(91, MusicRole.ATMOSPHERE, 5, "Pad Bowed"),
            ],
        }

    def get_preset(self, name: str) -> List[InstrumentMapping]:
        """Get instrument preset by name."""
        return self.presets.get(name, self.presets["Ensemble (melody+bass+pad etc.)"])

    def list_presets(self) -> List[str]:
        """Get list of available preset names."""
        return list(self.presets.keys())

# Music Generation Components

class MusicMathUtils:
    """Utility class for music-related mathematical operations."""

    @staticmethod
    def entropy(p: np.ndarray) -> float:
        """Calculate Shannon entropy of a probability distribution."""
        p = p / (p.sum() + 1e-9)
        return float(-np.sum(p * np.log2(p + 1e-9)))

    @staticmethod
    def quantize_time(time_val: int, grid: int = 120) -> int:
        """Quantize time value to grid."""
        return int(round(time_val / grid) * grid)

    @staticmethod
    def norm_to_scale(val: float, scale: np.ndarray, octave_range: int = 2) -> int:
        """Map normalized value to scale note with octave range."""
        octave = int(abs(val) * octave_range) * 12
        note_idx = int(abs(val * 100) % len(scale))
        return int(scale[note_idx] + octave)

    @staticmethod
    def apply_dynamics_curve(value: float, curve_type: str = "linear") -> float:
        """Apply dynamics curve to a value."""
        value = np.clip(value, 0, 1)
        if curve_type == "exponential":
            return value ** 2
        elif curve_type == "logarithmic":
            return np.log1p(value * np.e) / np.log1p(np.e)
        else:  # linear
            return value
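
# Worked examples for the helpers above (hand-checked values, for reference only):
#     MusicMathUtils.quantize_time(130, grid=120) -> 120    # 130/120 rounds to 1 grid step
#     MusicMathUtils.quantize_time(200, grid=120) -> 240    # 200/120 rounds to 2 grid steps
#     MusicMathUtils.entropy(np.array([0.5, 0.5])) -> ~1.0  # uniform 2-way split = 1 bit
#     MusicMathUtils.apply_dynamics_curve(0.5, "exponential") -> 0.25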

class NoteGenerator:
    """Generates notes based on neural network latents."""

    # Role-specific frequency multipliers
    ROLE_FREQUENCIES = {
        MusicRole.MELODY: 2.0,
        MusicRole.BASS: 0.5,
        MusicRole.HARMONY: 1.5,
        MusicRole.PAD: 0.25,
        MusicRole.ACCENT: 3.0,
        MusicRole.ATMOSPHERE: 0.33,
    }

    # Role-specific weight distributions
    ROLE_WEIGHTS = {
        MusicRole.MELODY: np.array([0.4, 0.2, 0.2, 0.1, 0.1]),
        MusicRole.BASS: np.array([0.1, 0.4, 0.1, 0.3, 0.1]),
        MusicRole.HARMONY: np.array([0.2, 0.2, 0.3, 0.2, 0.1]),
        MusicRole.PAD: np.array([0.1, 0.3, 0.1, 0.1, 0.4]),
        MusicRole.ACCENT: np.array([0.5, 0.1, 0.2, 0.1, 0.1]),
        MusicRole.ATMOSPHERE: np.array([0.1, 0.2, 0.1, 0.2, 0.4]),
    }

    def __init__(self, config: GenerationConfig):
        """Initialize with generation configuration."""
        self.config = config
        self.math_utils = MusicMathUtils()
        self.history: Dict[int, int] = {}

    def create_note_probability(
        self,
        layer_idx: int,
        token_idx: int,
        attention_val: float,
        hidden_state: np.ndarray,
        num_tokens: int,
        role: MusicRole
    ) -> float:
        """Calculate probability of playing a note based on multiple factors."""
        # Base probability from attention
        base_prob = 1 / (1 + np.exp(-10 * (attention_val - 0.5)))
        # Temporal factor based on role frequency
        temporal_factor = 0.5 + 0.5 * np.sin(
            2 * np.pi * self.ROLE_FREQUENCIES[role] * token_idx / max(1, num_tokens)
        )
        # Energy factor from hidden state norm
        energy = np.linalg.norm(hidden_state)
        energy_factor = np.tanh(energy / 10)
        # Variance factor
        local_variance = np.var(hidden_state)
        variance_factor = 1 - np.exp(-local_variance)
        # Entropy factor
        state_entropy = self.math_utils.entropy(np.abs(hidden_state))
        max_entropy = np.log2(max(2, hidden_state.shape[0]))
        entropy_factor = state_entropy / max_entropy
        # Combine factors with role-specific weights
        factors = np.array([base_prob, temporal_factor, energy_factor, variance_factor, entropy_factor])
        weights = self.ROLE_WEIGHTS[role]
        combined_prob = float(np.dot(weights, factors))
        # Add deterministic noise for variation
        noise_seed = layer_idx * 1000 + token_idx
        noise = 0.1 * (np.sin(noise_seed * 0.1) + np.cos(noise_seed * 0.23)) / 2
        # Apply shaping exponent and dynamics curve (clamp at 0 so negative noise
        # cannot feed a fractional power and produce NaN)
        final_prob = max(0.0, combined_prob + noise) ** 1.5
        final_prob = self.math_utils.apply_dynamics_curve(final_prob, self.config.dynamics_curve)
        return float(np.clip(final_prob, 0, 1))
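
    # Putting the factors above together, the note probability is effectively
    #     p = clip(curve(max(0, w_role . f + noise) ** 1.5), 0, 1)
    # with f = [sigmoid(attention), temporal, energy, variance, entropy] and
    # w_role taken from ROLE_WEIGHTS (a summary of this method, not extra spec).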

    def should_play_note(
        self,
        layer_idx: int,
        token_idx: int,
        attention_val: float,
        hidden_state: np.ndarray,
        num_tokens: int,
        role: MusicRole
    ) -> bool:
        """Determine if a note should be played."""
        prob = self.create_note_probability(
            layer_idx, token_idx, attention_val, hidden_state, num_tokens, role
        )
        # Adjust probability based on silence duration
        if layer_idx in self.history:
            last_played = self.history[layer_idx]
            silence_duration = token_idx - last_played
            prob *= (1 + np.tanh(silence_duration / 5) * 0.5)
        # Stochastic decision
        play_note = np.random.random() < prob
        if play_note:
            self.history[layer_idx] = token_idx
        return play_note

    def generate_notes_for_role(
        self,
        role: MusicRole,
        hidden_state: np.ndarray,
        scale: np.ndarray
    ) -> List[int]:
        """Generate notes based on role and hidden state."""
        if role == MusicRole.MELODY:
            note = self.math_utils.norm_to_scale(
                hidden_state[0], scale, octave_range=1
            )
            return [note]
        elif role == MusicRole.BASS:
            note = self.math_utils.norm_to_scale(
                hidden_state[0], scale, octave_range=0
            ) - 12
            return [note]
        elif role == MusicRole.HARMONY:
            return [
                self.math_utils.norm_to_scale(hidden_state[i], scale, octave_range=1)
                for i in range(0, min(2, len(hidden_state)), 1)
            ]
        elif role == MusicRole.PAD:
            return [
                self.math_utils.norm_to_scale(hidden_state[i], scale, octave_range=1)
                for i in range(0, min(3, len(hidden_state)), 2)
            ]
        elif role == MusicRole.ACCENT:
            note = self.math_utils.norm_to_scale(
                hidden_state[0], scale, octave_range=2
            ) + 12
            return [note]
        else:  # ATMOSPHERE
            return [
                self.math_utils.norm_to_scale(hidden_state[i], scale, octave_range=1)
                for i in range(0, min(2, len(hidden_state)), 3)
            ]

    def calculate_velocity(
        self,
        role: MusicRole,
        attention_strength: float
    ) -> int:
        """Calculate note velocity based on role and attention."""
        base_velocity = int(
            attention_strength * (self.config.velocity_range[1] - self.config.velocity_range[0])
            + self.config.velocity_range[0]
        )
        # Role-specific adjustments
        if role == MusicRole.MELODY:
            velocity = min(base_velocity + 10, 127)
        elif role == MusicRole.ACCENT:
            velocity = min(base_velocity + 20, 127)
        elif role in [MusicRole.PAD, MusicRole.ATMOSPHERE]:
            velocity = max(base_velocity - 10, 20)
        else:
            velocity = base_velocity
        return velocity

    def calculate_duration(
        self,
        role: MusicRole,
        attention_matrix: np.ndarray
    ) -> int:
        """Calculate note duration based on role and attention."""
        if role in [MusicRole.PAD, MusicRole.ATMOSPHERE]:
            duration = self.config.base_tempo * 4
        elif role == MusicRole.BASS:
            duration = self.config.base_tempo
        else:
            try:
                dur_factor = self.math_utils.entropy(attention_matrix.mean(axis=0)) / (
                    np.log2(attention_matrix.shape[-1]) + 1e-9
                )
            except Exception:
                dur_factor = 0.5
            duration = self.math_utils.quantize_time(
                int(self.config.base_tempo * (0.5 + dur_factor * 1.5)),
                self.config.quantization_grid
            )
        return duration

# Model Interaction

class LatentExtractor(ABC):
    """Abstract base class for latent extraction strategies."""

    @abstractmethod
    def extract(self, text: str, config: GenerationConfig, progress=None) -> Latents:
        """Extract latents from text."""
        pass

class MockLatentExtractor(LatentExtractor):
    """Generate mock latents for testing without loading models."""

    def extract(self, text: str, config: GenerationConfig, progress=None) -> Latents:
        """Generate synthetic latents based on text."""
        # Simulate token count based on text length
        tokens = max(16, min(128, len(text.split()) * 4))
        layers = min(config.num_layers_limit, 6)
        # Seed NumPy from the input text so different texts yield different latents
        # (note: Python's built-in hash() is salted per process, so the seed is only
        # stable within a single run)
        np.random.seed(hash(text) % 2**32)
        hidden_states = [
            torch.randn(1, tokens, 128) for _ in range(layers)
        ]
        attentions = [
            torch.rand(1, 8, tokens, tokens) for _ in range(layers)
        ]
        metadata = {
            "mode": "mock",
            "text_length": len(text),
            "generated_tokens": tokens,
            "generated_layers": layers,
        }
        return Latents(
            hidden_states=hidden_states,
            attentions=attentions,
            num_layers=layers,
            num_tokens=tokens,
            metadata=metadata
        )

class ModelLatentExtractor(LatentExtractor):
    """Extract real latents from transformer models."""

    # Assumption: this Space runs on ZeroGPU (`import spaces` above is otherwise
    # unused), so the heavy extraction is wrapped with @spaces.GPU to request a GPU
    # for the duration of the call.
    @spaces.GPU
    def extract(self, text: str, config: GenerationConfig, progress=None) -> Latents:
"""Extract latents from a real transformer model.""" | |
model_name = config.model_name | |
# Load tokenizer | |
tokenizer = AutoTokenizer.from_pretrained(model_name) | |
if tokenizer.pad_token is None and tokenizer.eos_token is not None: | |
tokenizer.pad_token = tokenizer.eos_token | |
# Configure model loading | |
load_kwargs = { | |
"output_hidden_states": True, | |
"output_attentions": True, | |
"device_map": "cuda" if torch.cuda.is_available() else "cpu", | |
} | |
# Set appropriate dtype | |
try: | |
load_kwargs["torch_dtype"] = ( | |
torch.bfloat16 if torch.cuda.is_available() else torch.float32 | |
) | |
except Exception: | |
pass | |
# Load model | |
model = AutoModel.from_pretrained(model_name, **load_kwargs) | |
# Tokenize input | |
inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512) | |
device = next(model.parameters()).device | |
inputs = {k: v.to(device) for k, v in inputs.items()} | |
# Get model outputs | |
with torch.no_grad(): | |
outputs = model(**inputs) | |
hidden_states = list(outputs.hidden_states) | |
attentions = list(outputs.attentions) | |
# Move to CPU to free VRAM | |
hidden_states = [hs.to("cpu") for hs in hidden_states] | |
attentions = [att.to("cpu") for att in attentions] | |
# Limit layers | |
layers = min(config.num_layers_limit, len(hidden_states)) | |
tokens = hidden_states[0].shape[1] | |
# Clean up | |
try: | |
del model | |
if torch.cuda.is_available(): | |
torch.cuda.empty_cache() | |
gc.collect() | |
except Exception: | |
pass | |
metadata = { | |
"mode": "full_model", | |
"model_name": model_name, | |
"actual_layers": len(hidden_states), | |
"used_layers": layers, | |
"tokens": tokens | |
} | |
return Latents( | |
hidden_states=hidden_states[:layers], | |
attentions=attentions[:layers], | |
num_layers=layers, | |
num_tokens=tokens, | |
metadata=metadata | |
) | |

class LatentExtractorFactory:
    """Factory for creating appropriate latent extractors."""

    @staticmethod
    def create(compute_mode: ComputeMode) -> LatentExtractor:
        """Create a latent extractor based on compute mode."""
        if compute_mode == ComputeMode.MOCK_LATENTS:
            return MockLatentExtractor()
        else:
            return ModelLatentExtractor()
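
# Extractor usage sketch (illustrative; `config` is any validated GenerationConfig):
#     extractor = LatentExtractorFactory.create(config.compute_mode)
#     latents = extractor.extract("some input text", config)
#     latents.hidden_states[0].shape  # e.g. torch.Size([1, num_tokens, hidden_dim])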

# MIDI Generation

class MIDIRenderer:
    """Renders MIDI files from latents."""

    def __init__(self, config: GenerationConfig, instrument_manager: InstrumentPresetManager):
        """Initialize MIDI renderer."""
        self.config = config
        self.instrument_manager = instrument_manager
        self.note_generator = NoteGenerator(config)
        self.math_utils = MusicMathUtils()

    def render(self, latents: Latents) -> Tuple[bytes, Dict[str, Any]]:
        """Render MIDI from latents."""
        # Set random seeds for reproducibility
        np.random.seed(self.config.seed)
        random.seed(self.config.seed)
        torch.manual_seed(self.config.seed)
        # Prepare data
        scale = np.array(self.config.scale.notes, dtype=int)
        num_layers = latents.num_layers
        num_tokens = latents.num_tokens
        # Convert tensors to numpy
        hidden_states = [
            hs.float().numpy() if isinstance(hs, torch.Tensor) else hs
            for hs in latents.hidden_states
        ]
        attentions = [
            att.float().numpy() if isinstance(att, torch.Tensor) else att
            for att in latents.attentions
        ]
        # Get instrument mappings
        instrument_mappings = self.instrument_manager.get_preset(self.config.instrument_preset)
        # Create MIDI file and tracks
        midi_file = MidiFile()
        tracks = self._create_tracks(midi_file, num_layers, instrument_mappings)
        # Generate notes
        stats = self._generate_notes(
            tracks, hidden_states, attentions,
            scale, num_tokens, instrument_mappings
        )
        # Convert to bytes
        bio = io.BytesIO()
        midi_file.save(file=bio)
        bio.seek(0)
        # Prepare metadata
        metadata = {
            "config": self.config.to_dict(),
            "latents_info": latents.metadata,
            "stats": stats,
            "timestamp": time.time(),
        }
        return bio.read(), metadata

    def _create_tracks(
        self,
        midi_file: MidiFile,
        num_layers: int,
        instrument_mappings: List[InstrumentMapping]
    ) -> List[MidiTrack]:
        """Create MIDI tracks with instrument assignments."""
        tracks = []
        for layer_idx in range(num_layers):
            track = MidiTrack()
            midi_file.tracks.append(track)
            tracks.append(track)
            # Get instrument mapping for this layer
            if layer_idx < len(instrument_mappings):
                mapping = instrument_mappings[layer_idx]
            else:
                # Default to piano if not enough mappings
                mapping = InstrumentMapping(0, MusicRole.MELODY, layer_idx % 16)
            # Set instrument
            track.append(Message(
                "program_change",
                program=mapping.program,
                time=0,
                channel=mapping.channel
            ))
            # Add track name
            if mapping.name:
                track.append(mido.MetaMessage(
                    "track_name",
                    name=f"{mapping.name} - {mapping.role.value}",
                    time=0
                ))
        return tracks

    def _generate_notes(
        self,
        tracks: List[MidiTrack],
        hidden_states: List[np.ndarray],
        attentions: List[np.ndarray],
        scale: np.ndarray,
        num_tokens: int,
        instrument_mappings: List[InstrumentMapping]
    ) -> Dict[str, Any]:
        """Generate notes for all tracks."""
        current_time = [0] * len(tracks)
        notes_count = [0] * len(tracks)
        for token_idx in range(num_tokens):
            # Update time periodically
            if token_idx > 0 and token_idx % 4 == 0:
                for layer_idx in range(len(tracks)):
                    current_time[layer_idx] += self.config.base_tempo
            # Calculate panning
            pan = 64 + int(32 * np.sin(token_idx * math.pi / max(1, num_tokens)))
            # Generate notes for each layer
            for layer_idx in range(len(tracks)):
                if layer_idx >= len(instrument_mappings):
                    continue
                mapping = instrument_mappings[layer_idx]
                # Get attention and hidden state
                attn_matrix = attentions[min(layer_idx, len(attentions) - 1)][0, :, token_idx, :]
                attention_strength = float(np.mean(attn_matrix))
                layer_vec = hidden_states[layer_idx][0, token_idx]
                # Check if note should be played
                if not self.note_generator.should_play_note(
                    layer_idx, token_idx, attention_strength,
                    layer_vec, num_tokens, mapping.role
                ):
                    continue
                # Generate notes
                notes_to_play = self.note_generator.generate_notes_for_role(
                    mapping.role, layer_vec, scale
                )
                # Calculate velocity and duration
                velocity = self.note_generator.calculate_velocity(
                    mapping.role, attention_strength
                )
                duration = self.note_generator.calculate_duration(
                    mapping.role, attn_matrix
                )
                # Add notes to track
                for note in notes_to_play:
                    note = max(21, min(108, int(note)))  # Clamp to piano range
                    tracks[layer_idx].append(Message(
                        "note_on",
                        note=note,
                        velocity=velocity,
                        time=current_time[layer_idx],
                        channel=mapping.channel
                    ))
                    tracks[layer_idx].append(Message(
                        "note_off",
                        note=note,
                        velocity=0,
                        time=duration,
                        channel=mapping.channel
                    ))
                    current_time[layer_idx] = 0
                    notes_count[layer_idx] += 1
                # Set panning on first token
                if token_idx == 0:
                    tracks[layer_idx].append(Message(
                        "control_change",
                        control=10,
                        value=pan,
                        time=0,
                        channel=mapping.channel
                    ))
        return {
            "num_layers": len(tracks),
            "num_tokens": num_tokens,
            "notes_per_layer": notes_count,
            "total_notes": int(sum(notes_count)),
            "tempo_ticks_per_beat": int(self.config.base_tempo),
            "scale": list(map(int, scale.tolist())),
        }

# Main Orchestrator

class LLMForestOrchestra:
    """Main orchestrator class that coordinates the entire pipeline."""

    DEFAULT_MODEL = "unsloth/Qwen3-14B-Base"

    def __init__(self):
        """Initialize the orchestra."""
        self.scale_manager = ScaleManager()
        self.instrument_manager = InstrumentPresetManager()
        self.saved_configs: Dict[str, GenerationConfig] = {}

    def generate(
        self,
        text: str,
        model_name: str,
        compute_mode: str,
        base_tempo: int,
        velocity_range: Tuple[int, int],
        scale_name: str,
        custom_scale_notes: Optional[List[int]],
        num_layers: int,
        instrument_preset: str,
        seed: int,
        quantization_grid: int = 120,
        octave_range: int = 2,
        dynamics_curve: str = "linear"
    ) -> Tuple[str, Dict[str, Any]]:
        """Generate MIDI from text input."""
        # Get or create scale
        if scale_name == "Custom":
            if not custom_scale_notes:
                raise ValueError("Custom scale requires note list")
            scale = ScaleDefinition("Custom", custom_scale_notes)
        else:
            scale = self.scale_manager.get_scale(scale_name)
            if scale is None:
                raise ValueError(f"Unknown scale: {scale_name}")
        # Create configuration
        config = GenerationConfig(
            model_name=model_name or self.DEFAULT_MODEL,
            compute_mode=ComputeMode(compute_mode),
            base_tempo=base_tempo,
            velocity_range=velocity_range,
            scale=scale,
            num_layers_limit=num_layers,
            seed=seed,
            instrument_preset=instrument_preset,
            quantization_grid=quantization_grid,
            octave_range=octave_range,
            dynamics_curve=dynamics_curve
        )
        # Validate configuration
        config.validate()
        # Extract latents
        extractor = LatentExtractorFactory.create(config.compute_mode)
        latents = extractor.extract(text, config)
        # Render MIDI
        renderer = MIDIRenderer(config, self.instrument_manager)
        midi_bytes, metadata = renderer.render(latents)
        # Save MIDI file
        filename = f"llm_forest_orchestra_{uuid.uuid4().hex[:8]}.mid"
        with open(filename, "wb") as f:
            f.write(midi_bytes)
        return filename, metadata
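
    # Example call (hypothetical argument values mirroring the UI defaults):
    #     orchestra = LLMForestOrchestra()
    #     midi_path, meta = orchestra.generate(
    #         text="a short poem", model_name=None, compute_mode="Mock latents",
    #         base_tempo=480, velocity_range=(40, 90), scale_name="C pentatonic",
    #         custom_scale_notes=None, num_layers=6,
    #         instrument_preset="Ensemble (melody+bass+pad etc.)", seed=42,
    #     )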

    def save_config(self, name: str, config: GenerationConfig):
        """Save a configuration for later use."""
        self.saved_configs[name] = config

    def load_config(self, name: str) -> Optional[GenerationConfig]:
        """Load a saved configuration."""
        return self.saved_configs.get(name)

    def export_config(self, config: GenerationConfig, filepath: str):
        """Export configuration to JSON file."""
        with open(filepath, "w") as f:
            json.dump(config.to_dict(), f, indent=2)

    def import_config(self, filepath: str) -> GenerationConfig:
        """Import configuration from JSON file."""
        with open(filepath, "r") as f:
            data = json.load(f)
        return GenerationConfig.from_dict(data, self.scale_manager)

# Gradio UI

class GradioInterface:
    """Manages the Gradio user interface."""

    DESCRIPTION = """
    # 🌲 LLM Forest Orchestra – Sonify Transformer Internals
    Transform the hidden states and attention patterns of language models into multi-layered musical compositions.
    ## 🍄 Inspiration
    This project is inspired by the way **mushrooms and mycelial networks in forests**
    connect plants and trees, forming a living web of communication and resource sharing.
    These connections can be turned into ethereal music.
    Just as signals move through these hidden connections, transformer models also
    pass hidden states and attentions across their layers. Here, those hidden
    connections are translated into **music**, analogous to the forest's secret orchestra.
    ## Features
    - **Two compute modes**: Full model (GPU) or Mock latents (CPU-friendly)
    - **Multiple musical scales**: From pentatonic to chromatic
    - **Instrument presets**: Orchestral, electronic, ensemble, and more
    - **Advanced controls**: Dynamics curves, quantization, velocity ranges
    - **Export**: Standard MIDI files for further editing in your DAW
    """

    EXAMPLE_TEXT = """Joy cascades in golden waterfalls, crashing into pools of melancholy blue.
Anger burns red through veins of marble, while serenity floats on clouds of softest grey.
Love pulses in waves of crimson and rose, intertwining with longing's purple haze.
Each feeling resonates at its own frequency, painting music across the soul's canvas."""

    def __init__(self, orchestra: LLMForestOrchestra):
        """Initialize the interface."""
        self.orchestra = orchestra

    def create_interface(self) -> gr.Blocks:
        """Create the Gradio interface."""
        with gr.Blocks(title="LLM Forest Orchestra", theme=gr.themes.Soft()) as demo:
            gr.Markdown(self.DESCRIPTION)
            with gr.Tabs():
                with gr.TabItem("🎵 Generate Music"):
                    self._create_generation_tab()
        return demo

    def _create_generation_tab(self):
        """Create the main generation tab."""
        with gr.Row():
            with gr.Column(scale=1):
                text_input = gr.Textbox(
                    value=self.EXAMPLE_TEXT,
                    label="Input Text",
                    lines=8,
                    placeholder="Enter text to sonify..."
                )
                model_name = gr.Textbox(
                    value=self.orchestra.DEFAULT_MODEL,
                    label="Hugging Face Model",
                    info="Model must support output_hidden_states and output_attentions"
                )
                compute_mode = gr.Radio(
                    choices=["Full model", "Mock latents"],
                    value="Mock latents",
                    label="Compute Mode",
                    info="Mock latents for quick CPU-only demo"
                )
                with gr.Row():
                    instrument_preset = gr.Dropdown(
                        choices=self.orchestra.instrument_manager.list_presets(),
                        value="Ensemble (melody+bass+pad etc.)",
                        label="Instrument Preset"
                    )
                    scale_choice = gr.Dropdown(
                        choices=self.orchestra.scale_manager.list_scales() + ["Custom"],
                        value="C pentatonic",
                        label="Musical Scale"
                    )
                custom_scale = gr.Textbox(
                    value="",
                    label="Custom Scale Notes",
                    placeholder="60,62,65,67,70",
                    visible=False
                )
                with gr.Row():
                    base_tempo = gr.Slider(
                        120, 960,
                        value=480,
                        step=1,
                        label="Tempo (ticks per beat)"
                    )
                    num_layers = gr.Slider(
                        1, 6,
                        value=6,
                        step=1,
                        label="Max Layers"
                    )
                with gr.Row():
                    velocity_low = gr.Slider(
                        1, 126,
                        value=40,
                        step=1,
                        label="Min Velocity"
                    )
                    velocity_high = gr.Slider(
                        2, 127,
                        value=90,
                        step=1,
                        label="Max Velocity"
                    )
                seed = gr.Number(
                    value=42,
                    precision=0,
                    label="Random Seed"
                )
                generate_btn = gr.Button(
                    "🎼 Generate MIDI",
                    variant="primary",
                    size="lg"
                )
            with gr.Column(scale=1):
                midi_output = gr.File(
                    label="Generated MIDI File",
                    file_types=[".mid", ".midi"]
                )
                stats_display = gr.Markdown(label="Quick Stats")
                metadata_json = gr.Code(
                    label="Metadata (JSON)",
                    language="json"
                )
        with gr.Row():
            play_instructions = gr.Markdown(
                """
                ### 🎧 How to Play
                1. Download the MIDI file
                2. Open in any DAW or MIDI player
                3. Adjust instruments and effects as desired
                4. Export to audio format
                """
            )

        # Set up interactions
        def update_custom_scale_visibility(choice):
            return gr.update(visible=(choice == "Custom"))

        scale_choice.change(
            update_custom_scale_visibility,
            inputs=[scale_choice],
            outputs=[custom_scale]
        )

        def generate_wrapper(
            text, model_name, compute_mode, base_tempo,
            velocity_low, velocity_high, scale_choice,
            custom_scale, num_layers, instrument_preset, seed
        ):
            """Wrapper for generation with error handling."""
            try:
                # Parse custom scale if needed
                custom_notes = None
                if scale_choice == "Custom" and custom_scale:
                    custom_notes = [int(x.strip()) for x in custom_scale.split(",")]
                # Generate
                filename, metadata = self.orchestra.generate(
                    text=text,
                    model_name=model_name,
                    compute_mode=compute_mode,
                    base_tempo=int(base_tempo),
                    velocity_range=(int(velocity_low), int(velocity_high)),
                    scale_name=scale_choice,
                    custom_scale_notes=custom_notes,
                    num_layers=int(num_layers),
                    instrument_preset=instrument_preset,
                    seed=int(seed)
                )
                # Format stats
                stats = metadata.get("stats", {})
                stats_text = f"""
                ### Generation Statistics
                - **Layers Used**: {stats.get('num_layers', 'N/A')}
                - **Tokens Processed**: {stats.get('num_tokens', 'N/A')}
                - **Total Notes**: {stats.get('total_notes', 'N/A')}
                - **Notes per Layer**: {stats.get('notes_per_layer', [])}
                - **Scale**: {stats.get('scale', [])}
                - **Tempo**: {stats.get('tempo_ticks_per_beat', 'N/A')} ticks/beat
                """
                return filename, stats_text, json.dumps(metadata, indent=2)
            except Exception as e:
                error_msg = f"### ❌ Error\n{str(e)}"
                return None, error_msg, json.dumps({"error": str(e)}, indent=2)

        generate_btn.click(
            fn=generate_wrapper,
            inputs=[
                text_input, model_name, compute_mode, base_tempo,
                velocity_low, velocity_high, scale_choice,
                custom_scale, num_layers, instrument_preset, seed
            ],
            outputs=[midi_output, stats_display, metadata_json]
        )

# Main Entry Point

def main():
    """Main entry point for the application."""
    # Initialize orchestra
    orchestra = LLMForestOrchestra()
    # Create interface
    interface = GradioInterface(orchestra)
    demo = interface.create_interface()
    # Launch
    demo.launch()


if __name__ == "__main__":
    main()