import os
import io
import gc
import math
import time
import uuid
import json
import spaces
import random
import zlib
from abc import ABC, abstractmethod
from dataclasses import dataclass, field, asdict
from typing import Dict, List, Tuple, Optional, Any, Union
from enum import Enum
import gradio as gr
import numpy as np
import torch
from transformers import AutoModel, AutoTokenizer
import mido
from mido import Message, MidiFile, MidiTrack
# Configuration Classes
class ComputeMode(Enum):
"""Enum for computation modes."""
FULL_MODEL = "Full model"
MOCK_LATENTS = "Mock latents"
class MusicRole(Enum):
"""Enum for musical roles/layers."""
MELODY = "melody"
BASS = "bass"
HARMONY = "harmony"
PAD = "pad"
ACCENT = "accent"
ATMOSPHERE = "atmosphere"
@dataclass
class ScaleDefinition:
"""Represents a musical scale."""
name: str
notes: List[int]
description: str = ""
def __post_init__(self):
"""Validate scale notes are within MIDI range."""
for note in self.notes:
if not 0 <= note <= 127:
raise ValueError(f"MIDI note {note} out of range (0-127)")
@dataclass
class InstrumentMapping:
"""Maps a layer to an instrument and musical role."""
program: int # MIDI program number
role: MusicRole
channel: int
name: str = ""
def __post_init__(self):
"""Validate MIDI program and channel."""
if not 0 <= self.program <= 127:
raise ValueError(f"MIDI program {self.program} out of range")
if not 0 <= self.channel <= 15:
raise ValueError(f"MIDI channel {self.channel} out of range")
@dataclass
class GenerationConfig:
"""Complete configuration for music generation."""
model_name: str
compute_mode: ComputeMode
base_tempo: int
velocity_range: Tuple[int, int]
scale: ScaleDefinition
num_layers_limit: int
seed: int
instrument_preset: str
# Additional configuration options
quantization_grid: int = 120
octave_range: int = 2
dynamics_curve: str = "linear" # linear, exponential, logarithmic
def validate(self):
"""Validate configuration parameters."""
if not 1 <= self.base_tempo <= 2000:
raise ValueError("Tempo must be between 1 and 2000")
if not 1 <= self.velocity_range[0] < self.velocity_range[1] <= 127:
raise ValueError("Invalid velocity range")
if not 1 <= self.num_layers_limit <= 32:
raise ValueError("Number of layers must be between 1 and 32")
def to_dict(self) -> Dict:
"""Convert config to dictionary for serialization."""
return {
"model_name": self.model_name,
"compute_mode": self.compute_mode.value,
"base_tempo": self.base_tempo,
"velocity_range": self.velocity_range,
"scale_name": self.scale.name,
"scale_notes": self.scale.notes,
"num_layers_limit": self.num_layers_limit,
"seed": self.seed,
"instrument_preset": self.instrument_preset,
"quantization_grid": self.quantization_grid,
"octave_range": self.octave_range,
"dynamics_curve": self.dynamics_curve
}
@classmethod
def from_dict(cls, data: Dict, scale_manager: "ScaleManager") -> "GenerationConfig":
"""Create config from dictionary."""
scale = scale_manager.get_scale(data["scale_name"])
if scale is None:
scale = ScaleDefinition(name="Custom", notes=data["scale_notes"])
return cls(
model_name=data["model_name"],
compute_mode=ComputeMode(data["compute_mode"]),
base_tempo=data["base_tempo"],
velocity_range=tuple(data["velocity_range"]),
scale=scale,
num_layers_limit=data["num_layers_limit"],
seed=data["seed"],
instrument_preset=data["instrument_preset"],
quantization_grid=data.get("quantization_grid", 120),
octave_range=data.get("octave_range", 2),
dynamics_curve=data.get("dynamics_curve", "linear")
)
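    # Illustrative round trip (not part of the pipeline): a config can be
    # serialized and restored via to_dict()/from_dict(), e.g. with the
    # ScaleManager defined below:
    #   sm = ScaleManager()
    #   cfg = GenerationConfig("unsloth/Qwen3-14B-Base", ComputeMode.MOCK_LATENTS,
    #                          480, (40, 90), sm.get_scale("C major"), 6, 42, "Orchestral")
    #   restored = GenerationConfig.from_dict(cfg.to_dict(), sm)
    #   assert restored.scale.name == "C major" and restored.seed == 42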
@dataclass
class Latents:
"""Container for model latents."""
hidden_states: List[torch.Tensor]
attentions: List[torch.Tensor]
num_layers: int
num_tokens: int
metadata: Dict[str, Any] = field(default_factory=dict)
# Music Components
class ScaleManager:
"""Manages musical scales and modes."""
def __init__(self):
"""Initialize with default scales."""
self.scales = {
"C pentatonic": ScaleDefinition(
"C pentatonic",
                [60, 62, 64, 67, 69, 72, 74, 76],
"Major pentatonic scale"
),
"C major": ScaleDefinition(
"C major",
[60, 62, 64, 65, 67, 69, 71, 72],
"Major scale (Ionian mode)"
),
"A minor": ScaleDefinition(
"A minor",
[57, 59, 60, 62, 64, 65, 67, 69],
"Natural minor scale (Aeolian mode)"
),
"D dorian": ScaleDefinition(
"D dorian",
[62, 64, 65, 67, 69, 71, 72, 74],
"Dorian mode - minor with raised 6th"
),
"E phrygian": ScaleDefinition(
"E phrygian",
[64, 65, 67, 69, 71, 72, 74, 76],
"Phrygian mode - minor with lowered 2nd"
),
"G mixolydian": ScaleDefinition(
"G mixolydian",
[67, 69, 71, 72, 74, 76, 77, 79],
"Mixolydian mode - major with lowered 7th"
),
"Blues scale": ScaleDefinition(
"Blues scale",
[60, 63, 65, 66, 67, 70, 72, 75],
"Blues scale with blue notes"
),
"Chromatic": ScaleDefinition(
"Chromatic",
list(range(60, 72)),
"All 12 semitones"
)
}
def get_scale(self, name: str) -> Optional[ScaleDefinition]:
"""Get scale by name."""
return self.scales.get(name)
def add_custom_scale(self, name: str, notes: List[int], description: str = "") -> ScaleDefinition:
"""Add a custom scale."""
scale = ScaleDefinition(name, notes, description)
self.scales[name] = scale
return scale
def list_scales(self) -> List[str]:
"""Get list of available scale names."""
return list(self.scales.keys())
class InstrumentPresetManager:
"""Manages instrument presets for different musical styles."""
def __init__(self):
"""Initialize with default presets."""
self.presets = {
"Ensemble (melody+bass+pad etc.)": [
InstrumentMapping(0, MusicRole.MELODY, 0, "Piano"),
InstrumentMapping(33, MusicRole.BASS, 1, "Electric Bass"),
InstrumentMapping(46, MusicRole.HARMONY, 2, "Harp"),
InstrumentMapping(48, MusicRole.PAD, 3, "String Ensemble"),
InstrumentMapping(11, MusicRole.ACCENT, 4, "Vibraphone"),
InstrumentMapping(89, MusicRole.ATMOSPHERE, 5, "Pad Warm")
],
"Piano Trio (melody+bass+harmony)": [
InstrumentMapping(0, MusicRole.MELODY, 0, "Piano"),
InstrumentMapping(33, MusicRole.BASS, 1, "Electric Bass"),
InstrumentMapping(0, MusicRole.HARMONY, 2, "Piano"),
InstrumentMapping(48, MusicRole.PAD, 3, "String Ensemble"),
InstrumentMapping(0, MusicRole.ACCENT, 4, "Piano"),
InstrumentMapping(0, MusicRole.ATMOSPHERE, 5, "Piano")
],
"Pads & Atmosphere": [
InstrumentMapping(48, MusicRole.PAD, 0, "String Ensemble"),
InstrumentMapping(48, MusicRole.PAD, 1, "String Ensemble"),
InstrumentMapping(89, MusicRole.ATMOSPHERE, 2, "Pad Warm"),
InstrumentMapping(89, MusicRole.ATMOSPHERE, 3, "Pad Warm"),
InstrumentMapping(46, MusicRole.HARMONY, 4, "Harp"),
InstrumentMapping(11, MusicRole.ACCENT, 5, "Vibraphone")
],
"Orchestral": [
InstrumentMapping(40, MusicRole.MELODY, 0, "Violin"),
InstrumentMapping(42, MusicRole.BASS, 1, "Cello"),
InstrumentMapping(46, MusicRole.HARMONY, 2, "Harp"),
InstrumentMapping(48, MusicRole.PAD, 3, "String Ensemble"),
InstrumentMapping(73, MusicRole.ACCENT, 4, "Flute"),
InstrumentMapping(49, MusicRole.ATMOSPHERE, 5, "Slow Strings")
],
"Electronic": [
InstrumentMapping(80, MusicRole.MELODY, 0, "Lead Square"),
InstrumentMapping(38, MusicRole.BASS, 1, "Synth Bass"),
InstrumentMapping(81, MusicRole.HARMONY, 2, "Lead Sawtooth"),
InstrumentMapping(90, MusicRole.PAD, 3, "Pad Polysynth"),
InstrumentMapping(82, MusicRole.ACCENT, 4, "Lead Calliope"),
InstrumentMapping(91, MusicRole.ATMOSPHERE, 5, "Pad Bowed")
]
}
def get_preset(self, name: str) -> List[InstrumentMapping]:
"""Get instrument preset by name."""
return self.presets.get(name, self.presets["Ensemble (melody+bass+pad etc.)"])
def list_presets(self) -> List[str]:
"""Get list of available preset names."""
return list(self.presets.keys())
# Music Generation Components
class MusicMathUtils:
"""Utility class for music-related mathematical operations."""
@staticmethod
def entropy(p: np.ndarray) -> float:
"""Calculate Shannon entropy of a probability distribution."""
p = p / (p.sum() + 1e-9)
return float(-np.sum(p * np.log2(p + 1e-9)))
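    # Example (illustrative): entropy(np.array([0.25, 0.25, 0.25, 0.25])) ≈ 2.0 bits,
    # while a one-hot distribution like np.array([1.0, 0.0, 0.0, 0.0]) gives ≈ 0.0.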
@staticmethod
def quantize_time(time_val: int, grid: int = 120) -> int:
"""Quantize time value to grid."""
return int(round(time_val / grid) * grid)
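    # Example (illustrative): quantize_time(290, grid=120) -> 240, since
    # 290 / 120 ≈ 2.42 rounds to 2 grid steps of 120 ticks.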
@staticmethod
def norm_to_scale(val: float, scale: np.ndarray, octave_range: int = 2) -> int:
"""Map normalized value to scale note with octave range."""
octave = int(abs(val) * octave_range) * 12
note_idx = int(abs(val * 100) % len(scale))
return int(scale[note_idx] + octave)
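    # Example (illustrative): with scale = np.array([60, 62, 64, 67, 69]) and the
    # default octave_range=2, norm_to_scale(0.37, scale) -> 64 (octave offset 0,
    # index 37 % 5 == 2), and norm_to_scale(-1.2, scale) -> 84 (offset +24, index 0).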
@staticmethod
def apply_dynamics_curve(value: float, curve_type: str = "linear") -> float:
"""Apply dynamics curve to a value."""
value = np.clip(value, 0, 1)
if curve_type == "exponential":
return value ** 2
elif curve_type == "logarithmic":
return np.log1p(value * np.e) / np.log1p(np.e)
else: # linear
return value
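    # Example (illustrative): apply_dynamics_curve(0.5, "exponential") -> 0.25;
    # "logarithmic" lifts mid values instead: log1p(0.5 * e) / log1p(e) ≈ 0.65.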
class NoteGenerator:
"""Generates notes based on neural network latents."""
# Role-specific frequency multipliers
ROLE_FREQUENCIES = {
MusicRole.MELODY: 2.0,
MusicRole.BASS: 0.5,
MusicRole.HARMONY: 1.5,
MusicRole.PAD: 0.25,
MusicRole.ACCENT: 3.0,
MusicRole.ATMOSPHERE: 0.33
}
# Role-specific weight distributions
ROLE_WEIGHTS = {
MusicRole.MELODY: np.array([0.4, 0.2, 0.2, 0.1, 0.1]),
MusicRole.BASS: np.array([0.1, 0.4, 0.1, 0.3, 0.1]),
MusicRole.HARMONY: np.array([0.2, 0.2, 0.3, 0.2, 0.1]),
MusicRole.PAD: np.array([0.1, 0.3, 0.1, 0.1, 0.4]),
MusicRole.ACCENT: np.array([0.5, 0.1, 0.2, 0.1, 0.1]),
MusicRole.ATMOSPHERE: np.array([0.1, 0.2, 0.1, 0.2, 0.4])
}
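    # Note: each weight vector above orders its factors as
    # [attention, temporal, energy, variance, entropy] and sums to 1.0, so the
    # combined score in create_note_probability() is a convex mix of values in [0, 1].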
def __init__(self, config: GenerationConfig):
"""Initialize with generation configuration."""
self.config = config
self.math_utils = MusicMathUtils()
self.history: Dict[int, int] = {}
def create_note_probability(
self,
layer_idx: int,
token_idx: int,
attention_val: float,
hidden_state: np.ndarray,
num_tokens: int,
role: MusicRole
) -> float:
"""Calculate probability of playing a note based on multiple factors."""
# Base probability from attention
base_prob = 1 / (1 + np.exp(-10 * (attention_val - 0.5)))
# Temporal factor based on role frequency
temporal_factor = 0.5 + 0.5 * np.sin(
2 * np.pi * self.ROLE_FREQUENCIES[role] * token_idx / max(1, num_tokens)
)
# Energy factor from hidden state norm
energy = np.linalg.norm(hidden_state)
energy_factor = np.tanh(energy / 10)
# Variance factor
local_variance = np.var(hidden_state)
variance_factor = 1 - np.exp(-local_variance)
# Entropy factor
state_entropy = self.math_utils.entropy(np.abs(hidden_state))
max_entropy = np.log2(max(2, hidden_state.shape[0]))
entropy_factor = state_entropy / max_entropy
# Combine factors with role-specific weights
factors = np.array([base_prob, temporal_factor, energy_factor, variance_factor, entropy_factor])
weights = self.ROLE_WEIGHTS[role]
combined_prob = float(np.dot(weights, factors))
# Add deterministic noise for variation
noise_seed = layer_idx * 1000 + token_idx
noise = 0.1 * (np.sin(noise_seed * 0.1) + np.cos(noise_seed * 0.23)) / 2
# Apply dynamics curve
final_prob = (combined_prob + noise) ** 1.5
final_prob = self.math_utils.apply_dynamics_curve(final_prob, self.config.dynamics_curve)
return float(np.clip(final_prob, 0, 1))
def should_play_note(
self,
layer_idx: int,
token_idx: int,
attention_val: float,
hidden_state: np.ndarray,
num_tokens: int,
role: MusicRole
) -> bool:
"""Determine if a note should be played."""
prob = self.create_note_probability(
layer_idx, token_idx, attention_val, hidden_state, num_tokens, role
)
# Adjust probability based on silence duration
if layer_idx in self.history:
last_played = self.history[layer_idx]
silence_duration = token_idx - last_played
prob *= (1 + np.tanh(silence_duration / 5) * 0.5)
# Stochastic decision
play_note = np.random.random() < prob
if play_note:
self.history[layer_idx] = token_idx
return play_note
def generate_notes_for_role(
self,
role: MusicRole,
hidden_state: np.ndarray,
scale: np.ndarray
) -> List[int]:
"""Generate notes based on role and hidden state."""
if role == MusicRole.MELODY:
note = self.math_utils.norm_to_scale(
hidden_state[0], scale, octave_range=1
)
return [note]
elif role == MusicRole.BASS:
note = self.math_utils.norm_to_scale(
hidden_state[0], scale, octave_range=0
) - 12
return [note]
elif role == MusicRole.HARMONY:
return [
self.math_utils.norm_to_scale(hidden_state[i], scale, octave_range=1)
for i in range(0, min(2, len(hidden_state)), 1)
]
elif role == MusicRole.PAD:
return [
self.math_utils.norm_to_scale(hidden_state[i], scale, octave_range=1)
for i in range(0, min(3, len(hidden_state)), 2)
]
elif role == MusicRole.ACCENT:
note = self.math_utils.norm_to_scale(
hidden_state[0], scale, octave_range=2
) + 12
return [note]
else: # ATMOSPHERE
return [
self.math_utils.norm_to_scale(hidden_state[i], scale, octave_range=1)
for i in range(0, min(2, len(hidden_state)), 3)
]
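    # Summary of the branches above: MELODY, BASS and ACCENT each emit one note
    # (BASS an octave down, ACCENT an octave up); HARMONY yields up to two notes
    # (indices 0-1), PAD up to two (indices 0 and 2), and ATMOSPHERE only index 0
    # because its range step of 3 exceeds min(2, len(hidden_state)).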
def calculate_velocity(
self,
role: MusicRole,
attention_strength: float
) -> int:
"""Calculate note velocity based on role and attention."""
base_velocity = int(
attention_strength * (self.config.velocity_range[1] - self.config.velocity_range[0])
+ self.config.velocity_range[0]
)
# Role-specific adjustments
if role == MusicRole.MELODY:
velocity = min(base_velocity + 10, 127)
elif role == MusicRole.ACCENT:
velocity = min(base_velocity + 20, 127)
elif role in [MusicRole.PAD, MusicRole.ATMOSPHERE]:
velocity = max(base_velocity - 10, 20)
else:
velocity = base_velocity
return velocity
def calculate_duration(
self,
role: MusicRole,
attention_matrix: np.ndarray
) -> int:
"""Calculate note duration based on role and attention."""
if role in [MusicRole.PAD, MusicRole.ATMOSPHERE]:
duration = self.config.base_tempo * 4
elif role == MusicRole.BASS:
duration = self.config.base_tempo
else:
try:
dur_factor = self.math_utils.entropy(attention_matrix.mean(axis=0)) / (
np.log2(attention_matrix.shape[-1]) + 1e-9
)
except Exception:
dur_factor = 0.5
duration = self.math_utils.quantize_time(
int(self.config.base_tempo * (0.5 + dur_factor * 1.5)),
self.config.quantization_grid
)
return duration
# Model Interaction
class LatentExtractor(ABC):
"""Abstract base class for latent extraction strategies."""
@abstractmethod
def extract(self, text: str, config: GenerationConfig, progress=None) -> Latents:
"""Extract latents from text."""
pass
class MockLatentExtractor(LatentExtractor):
"""Generate mock latents for testing without loading models."""
def extract(self, text: str, config: GenerationConfig, progress=None) -> Latents:
"""Generate synthetic latents based on text."""
# Simulate token count based on text length
tokens = max(16, min(128, len(text.split()) * 4))
layers = min(config.num_layers_limit, 6)
        # Seed numpy from a stable hash of the text so the same prompt always
        # yields the same mock latents (Python's built-in hash() is salted per
        # process, so it is not reproducible across runs).
        np.random.seed(zlib.crc32(text.encode("utf-8")) & 0xFFFFFFFF)
hidden_states = [
torch.randn(1, tokens, 128) for _ in range(layers)
]
attentions = [
torch.rand(1, 8, tokens, tokens) for _ in range(layers)
]
metadata = {
"mode": "mock",
"text_length": len(text),
"generated_tokens": tokens,
"generated_layers": layers
}
return Latents(
hidden_states=hidden_states,
attentions=attentions,
num_layers=layers,
num_tokens=tokens,
metadata=metadata
)
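    # Illustrative shapes from extract() above: a 20-word prompt yields 80 mock
    # tokens, so each hidden_states[i] is (1, 80, 128) and each attentions[i] is
    # (1, 8, 80, 80); the layer count is capped at min(num_layers_limit, 6).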
class ModelLatentExtractor(LatentExtractor):
"""Extract real latents from transformer models."""
@spaces.GPU(duration=45)
def extract(self, text: str, config: GenerationConfig, progress=None) -> Latents:
"""Extract latents from a real transformer model."""
model_name = config.model_name
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token is None and tokenizer.eos_token is not None:
tokenizer.pad_token = tokenizer.eos_token
# Configure model loading
load_kwargs = {
"output_hidden_states": True,
"output_attentions": True,
"device_map": "cuda" if torch.cuda.is_available() else "cpu",
}
# Set appropriate dtype
try:
load_kwargs["torch_dtype"] = (
torch.bfloat16 if torch.cuda.is_available() else torch.float32
)
except Exception:
pass
# Load model
model = AutoModel.from_pretrained(model_name, **load_kwargs)
# Tokenize input
inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
device = next(model.parameters()).device
inputs = {k: v.to(device) for k, v in inputs.items()}
# Get model outputs
with torch.no_grad():
outputs = model(**inputs)
hidden_states = list(outputs.hidden_states)
attentions = list(outputs.attentions)
# Move to CPU to free VRAM
hidden_states = [hs.to("cpu") for hs in hidden_states]
attentions = [att.to("cpu") for att in attentions]
# Limit layers
layers = min(config.num_layers_limit, len(hidden_states))
tokens = hidden_states[0].shape[1]
# Clean up
try:
del model
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
except Exception:
pass
metadata = {
"mode": "full_model",
"model_name": model_name,
"actual_layers": len(hidden_states),
"used_layers": layers,
"tokens": tokens
}
return Latents(
hidden_states=hidden_states[:layers],
attentions=attentions[:layers],
num_layers=layers,
num_tokens=tokens,
metadata=metadata
)
class LatentExtractorFactory:
"""Factory for creating appropriate latent extractors."""
@staticmethod
def create(compute_mode: ComputeMode) -> LatentExtractor:
"""Create a latent extractor based on compute mode."""
if compute_mode == ComputeMode.MOCK_LATENTS:
return MockLatentExtractor()
else:
return ModelLatentExtractor()
# MIDI Generation
class MIDIRenderer:
"""Renders MIDI files from latents."""
def __init__(self, config: GenerationConfig, instrument_manager: InstrumentPresetManager):
"""Initialize MIDI renderer."""
self.config = config
self.instrument_manager = instrument_manager
self.note_generator = NoteGenerator(config)
self.math_utils = MusicMathUtils()
def render(self, latents: Latents) -> Tuple[bytes, Dict[str, Any]]:
"""Render MIDI from latents."""
# Set random seeds for reproducibility
np.random.seed(self.config.seed)
random.seed(self.config.seed)
torch.manual_seed(self.config.seed)
# Prepare data
scale = np.array(self.config.scale.notes, dtype=int)
num_layers = latents.num_layers
num_tokens = latents.num_tokens
# Convert tensors to numpy
hidden_states = [
hs.float().numpy() if isinstance(hs, torch.Tensor) else hs
for hs in latents.hidden_states
]
attentions = [
att.float().numpy() if isinstance(att, torch.Tensor) else att
for att in latents.attentions
]
# Get instrument mappings
instrument_mappings = self.instrument_manager.get_preset(self.config.instrument_preset)
# Create MIDI file and tracks
midi_file = MidiFile()
tracks = self._create_tracks(midi_file, num_layers, instrument_mappings)
# Generate notes
stats = self._generate_notes(
tracks, hidden_states, attentions,
scale, num_tokens, instrument_mappings
)
# Convert to bytes
bio = io.BytesIO()
midi_file.save(file=bio)
bio.seek(0)
# Prepare metadata
metadata = {
"config": self.config.to_dict(),
"latents_info": latents.metadata,
"stats": stats,
"timestamp": time.time()
}
return bio.read(), metadata
def _create_tracks(
self,
midi_file: MidiFile,
num_layers: int,
instrument_mappings: List[InstrumentMapping]
) -> List[MidiTrack]:
"""Create MIDI tracks with instrument assignments."""
tracks = []
for layer_idx in range(num_layers):
track = MidiTrack()
midi_file.tracks.append(track)
tracks.append(track)
# Get instrument mapping for this layer
if layer_idx < len(instrument_mappings):
mapping = instrument_mappings[layer_idx]
else:
# Default to piano if not enough mappings
mapping = InstrumentMapping(0, MusicRole.MELODY, layer_idx % 16)
# Set instrument
track.append(Message(
"program_change",
program=mapping.program,
time=0,
channel=mapping.channel
))
# Add track name
if mapping.name:
track.append(mido.MetaMessage(
"track_name",
name=f"{mapping.name} - {mapping.role.value}",
time=0
))
return tracks
def _generate_notes(
self,
tracks: List[MidiTrack],
hidden_states: List[np.ndarray],
attentions: List[np.ndarray],
scale: np.ndarray,
num_tokens: int,
instrument_mappings: List[InstrumentMapping]
) -> Dict[str, Any]:
"""Generate notes for all tracks."""
current_time = [0] * len(tracks)
notes_count = [0] * len(tracks)
for token_idx in range(num_tokens):
# Update time periodically
if token_idx > 0 and token_idx % 4 == 0:
for layer_idx in range(len(tracks)):
current_time[layer_idx] += self.config.base_tempo
# Calculate panning
pan = 64 + int(32 * np.sin(token_idx * math.pi / max(1, num_tokens)))
# Generate notes for each layer
for layer_idx in range(len(tracks)):
if layer_idx >= len(instrument_mappings):
continue
mapping = instrument_mappings[layer_idx]
# Get attention and hidden state
attn_matrix = attentions[min(layer_idx, len(attentions) - 1)][0, :, token_idx, :]
attention_strength = float(np.mean(attn_matrix))
layer_vec = hidden_states[layer_idx][0, token_idx]
# Check if note should be played
if not self.note_generator.should_play_note(
layer_idx, token_idx, attention_strength,
layer_vec, num_tokens, mapping.role
):
continue
# Generate notes
notes_to_play = self.note_generator.generate_notes_for_role(
mapping.role, layer_vec, scale
)
# Calculate velocity and duration
velocity = self.note_generator.calculate_velocity(
mapping.role, attention_strength
)
duration = self.note_generator.calculate_duration(
mapping.role, attn_matrix
)
# Add notes to track
for note in notes_to_play:
note = max(21, min(108, int(note))) # Clamp to piano range
tracks[layer_idx].append(Message(
"note_on",
note=note,
velocity=velocity,
time=current_time[layer_idx],
channel=mapping.channel
))
tracks[layer_idx].append(Message(
"note_off",
note=note,
velocity=0,
time=duration,
channel=mapping.channel
))
current_time[layer_idx] = 0
notes_count[layer_idx] += 1
# Set panning on first token
if token_idx == 0:
tracks[layer_idx].append(Message(
"control_change",
control=10,
value=pan,
time=0,
channel=mapping.channel
))
return {
"num_layers": len(tracks),
"num_tokens": num_tokens,
"notes_per_layer": notes_count,
"total_notes": int(sum(notes_count)),
"tempo_ticks_per_beat": int(self.config.base_tempo),
"scale": list(map(int, scale.tolist())),
}
# Main Orchestrator
class LLMForestOrchestra:
"""Main orchestrator class that coordinates the entire pipeline."""
DEFAULT_MODEL = "unsloth/Qwen3-14B-Base"
def __init__(self):
"""Initialize the orchestra."""
self.scale_manager = ScaleManager()
self.instrument_manager = InstrumentPresetManager()
self.saved_configs: Dict[str, GenerationConfig] = {}
def generate(
self,
text: str,
model_name: str,
compute_mode: str,
base_tempo: int,
velocity_range: Tuple[int, int],
scale_name: str,
custom_scale_notes: Optional[List[int]],
num_layers: int,
instrument_preset: str,
seed: int,
quantization_grid: int = 120,
octave_range: int = 2,
dynamics_curve: str = "linear"
) -> Tuple[str, Dict[str, Any]]:
"""Generate MIDI from text input."""
# Get or create scale
if scale_name == "Custom":
if not custom_scale_notes:
raise ValueError("Custom scale requires note list")
scale = ScaleDefinition("Custom", custom_scale_notes)
else:
scale = self.scale_manager.get_scale(scale_name)
if scale is None:
raise ValueError(f"Unknown scale: {scale_name}")
# Create configuration
config = GenerationConfig(
model_name=model_name or self.DEFAULT_MODEL,
compute_mode=ComputeMode(compute_mode),
base_tempo=base_tempo,
velocity_range=velocity_range,
scale=scale,
num_layers_limit=num_layers,
seed=seed,
instrument_preset=instrument_preset,
quantization_grid=quantization_grid,
octave_range=octave_range,
dynamics_curve=dynamics_curve
)
# Validate configuration
config.validate()
# Extract latents
extractor = LatentExtractorFactory.create(config.compute_mode)
latents = extractor.extract(text, config)
# Render MIDI
renderer = MIDIRenderer(config, self.instrument_manager)
midi_bytes, metadata = renderer.render(latents)
# Save MIDI file
filename = f"llm_forest_orchestra_{uuid.uuid4().hex[:8]}.mid"
with open(filename, "wb") as f:
f.write(midi_bytes)
return filename, metadata
def save_config(self, name: str, config: GenerationConfig):
"""Save a configuration for later use."""
self.saved_configs[name] = config
def load_config(self, name: str) -> Optional[GenerationConfig]:
"""Load a saved configuration."""
return self.saved_configs.get(name)
def export_config(self, config: GenerationConfig, filepath: str):
"""Export configuration to JSON file."""
with open(filepath, "w") as f:
json.dump(config.to_dict(), f, indent=2)
def import_config(self, filepath: str) -> GenerationConfig:
"""Import configuration from JSON file."""
with open(filepath, "r") as f:
data = json.load(f)
return GenerationConfig.from_dict(data, self.scale_manager)
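# Illustrative programmatic use without the Gradio UI (mock latents, no GPU needed):
#   orchestra = LLMForestOrchestra()
#   midi_path, metadata = orchestra.generate(
#       text="The forest hums with hidden signals.",
#       model_name="",                      # empty string falls back to DEFAULT_MODEL
#       compute_mode="Mock latents",
#       base_tempo=480,
#       velocity_range=(40, 90),
#       scale_name="C major",
#       custom_scale_notes=None,
#       num_layers=4,
#       instrument_preset="Orchestral",
#       seed=42,
#   )
#   print(midi_path, metadata["stats"]["total_notes"])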
# Gradio UI
class GradioInterface:
"""Manages the Gradio user interface."""
DESCRIPTION = """
# 🌲 LLM Forest Orchestra – Sonify Transformer Internals
Transform the hidden states and attention patterns of language models into multi-layered musical compositions.
## πŸ„ Inspiration
This project is inspired by the way **mushrooms and mycelial networks in forests**
connect plants and trees, forming a living web of communication and resource sharing.
These connections can be turned into ethereal music.
Just as signals move through these hidden connections, transformer models also
pass hidden states and attention patterns across their layers. Here, those hidden
connections are translated into **music**, analogous to the forest's secret orchestra.
## Features
- **Two compute modes**: Full model (GPU) or Mock latents (CPU-friendly)
- **Multiple musical scales**: From pentatonic to chromatic
- **Instrument presets**: Orchestral, electronic, ensemble, and more
- **Advanced controls**: Dynamics curves, quantization, velocity ranges
- **Export**: Standard MIDI files for further editing in your DAW
"""
EXAMPLE_TEXT = """Joy cascades in golden waterfalls, crashing into pools of melancholy blue.
Anger burns red through veins of marble, while serenity floats on clouds of softest grey.
Love pulses in waves of crimson and rose, intertwining with longing's purple haze.
Each feeling resonates at its own frequency, painting music across the soul's canvas."""
def __init__(self, orchestra: LLMForestOrchestra):
"""Initialize the interface."""
self.orchestra = orchestra
def create_interface(self) -> gr.Blocks:
"""Create the Gradio interface."""
with gr.Blocks(title="LLM Forest Orchestra", theme=gr.themes.Soft()) as demo:
gr.Markdown(self.DESCRIPTION)
with gr.Tabs():
                with gr.TabItem("🎵 Generate Music"):
self._create_generation_tab()
return demo
def _create_generation_tab(self):
"""Create the main generation tab."""
with gr.Row():
with gr.Column(scale=1):
text_input = gr.Textbox(
value=self.EXAMPLE_TEXT,
label="Input Text",
lines=8,
placeholder="Enter text to sonify..."
)
model_name = gr.Textbox(
value=self.orchestra.DEFAULT_MODEL,
label="Hugging Face Model",
info="Model must support output_hidden_states and output_attentions"
)
compute_mode = gr.Radio(
choices=["Full model", "Mock latents"],
value="Mock latents",
label="Compute Mode",
info="Mock latents for quick CPU-only demo"
)
with gr.Row():
instrument_preset = gr.Dropdown(
choices=self.orchestra.instrument_manager.list_presets(),
value="Ensemble (melody+bass+pad etc.)",
label="Instrument Preset"
)
scale_choice = gr.Dropdown(
choices=self.orchestra.scale_manager.list_scales() + ["Custom"],
value="C pentatonic",
label="Musical Scale"
)
custom_scale = gr.Textbox(
value="",
label="Custom Scale Notes",
placeholder="60,62,65,67,70",
visible=False
)
with gr.Row():
base_tempo = gr.Slider(
120, 960,
value=480,
step=1,
label="Tempo (ticks per beat)"
)
num_layers = gr.Slider(
1, 6,
value=6,
step=1,
label="Max Layers"
)
with gr.Row():
velocity_low = gr.Slider(
1, 126,
value=40,
step=1,
label="Min Velocity"
)
velocity_high = gr.Slider(
2, 127,
value=90,
step=1,
label="Max Velocity"
)
seed = gr.Number(
value=42,
precision=0,
label="Random Seed"
)
generate_btn = gr.Button(
"🎼 Generate MIDI",
variant="primary",
size="lg"
)
with gr.Column(scale=1):
midi_output = gr.File(
label="Generated MIDI File",
file_types=[".mid", ".midi"]
)
stats_display = gr.Markdown(label="Quick Stats")
metadata_json = gr.Code(
label="Metadata (JSON)",
language="json"
)
with gr.Row():
play_instructions = gr.Markdown(
"""
### 🎧 How to Play
1. Download the MIDI file
2. Open in any DAW or MIDI player
3. Adjust instruments and effects as desired
4. Export to audio format
"""
)
# Set up interactions
def update_custom_scale_visibility(choice):
return gr.update(visible=(choice == "Custom"))
scale_choice.change(
update_custom_scale_visibility,
inputs=[scale_choice],
outputs=[custom_scale]
)
def generate_wrapper(
text, model_name, compute_mode, base_tempo,
velocity_low, velocity_high, scale_choice,
custom_scale, num_layers, instrument_preset, seed
):
"""Wrapper for generation with error handling."""
try:
# Parse custom scale if needed
custom_notes = None
if scale_choice == "Custom" and custom_scale:
custom_notes = [int(x.strip()) for x in custom_scale.split(",")]
# Generate
filename, metadata = self.orchestra.generate(
text=text,
model_name=model_name,
compute_mode=compute_mode,
base_tempo=int(base_tempo),
velocity_range=(int(velocity_low), int(velocity_high)),
scale_name=scale_choice,
custom_scale_notes=custom_notes,
num_layers=int(num_layers),
instrument_preset=instrument_preset,
seed=int(seed)
)
# Format stats
stats = metadata.get("stats", {})
stats_text = f"""
### Generation Statistics
- **Layers Used**: {stats.get('num_layers', 'N/A')}
- **Tokens Processed**: {stats.get('num_tokens', 'N/A')}
- **Total Notes**: {stats.get('total_notes', 'N/A')}
- **Notes per Layer**: {stats.get('notes_per_layer', [])}
- **Scale**: {stats.get('scale', [])}
- **Tempo**: {stats.get('tempo_ticks_per_beat', 'N/A')} ticks/beat
"""
return filename, stats_text, json.dumps(metadata, indent=2)
except Exception as e:
error_msg = f"### ❌ Error\n{str(e)}"
return None, error_msg, json.dumps({"error": str(e)}, indent=2)
generate_btn.click(
fn=generate_wrapper,
inputs=[
text_input, model_name, compute_mode, base_tempo,
velocity_low, velocity_high, scale_choice,
custom_scale, num_layers, instrument_preset, seed
],
outputs=[midi_output, stats_display, metadata_json]
)
# Main Entry Point
def main():
"""Main entry point for the application."""
# Initialize orchestra
orchestra = LLMForestOrchestra()
# Create interface
interface = GradioInterface(orchestra)
demo = interface.create_interface()
# Launch
demo.launch()
if __name__ == "__main__":
main()