Spaces: Running on Zero
Big refactor + more features
app.py CHANGED
@@ -4,11 +4,13 @@ import gc
import math
import time
import uuid
import spaces
import random
-from …
-from …
-import …

import gradio as gr
import numpy as np
@@ -18,386 +20,1142 @@ from transformers import AutoModel, AutoTokenizer
import mido
from mido import Message, MidiFile, MidiTrack

-… (removed lines 21-57 truncated in this view)

@dataclass
-class …
    model_name: str
-    compute_mode: …
    base_tempo: int
    velocity_range: Tuple[int, int]
-    scale: …
    num_layers_limit: int
    seed: int

-# --- Core math helpers ---
-
-def entropy(p: np.ndarray) -> float:
-    p = p / (p.sum() + 1e-9)
-    return float(-np.sum(p * np.log2(p + 1e-9)))
-
-def quantize_time(time_val: int, grid: int = 120) -> int:
-    return int(round(time_val / grid) * grid)
-
-def norm_to_scale(val: float, scale: np.ndarray, octave_range: int = 2) -> int:
-    octave = int(abs(val) * octave_range) * 12
-    note_idx = int(abs(val * 100) % len(scale))
-    return int(scale[note_idx] + octave)
-
-ROLE_FREQS = {
-    'melody': 2.0,
-    'bass': 0.5,
-    'harmony': 1.5,
-    'pad': 0.25,
-    'accent': 3.0,
-    'atmosphere': 0.33
-}
-
-ROLE_WEIGHTS = {
-    'melody': np.array([0.4, 0.2, 0.2, 0.1, 0.1]),
-    'bass': np.array([0.1, 0.4, 0.1, 0.3, 0.1]),
-    'harmony': np.array([0.2, 0.2, 0.3, 0.2, 0.1]),
-    'pad': np.array([0.1, 0.3, 0.1, 0.1, 0.4]),
-    'accent': np.array([0.5, 0.1, 0.2, 0.1, 0.1]),
-    'atmosphere': np.array([0.1, 0.2, 0.1, 0.2, 0.4])
-}
-
-def create_note_probability(layer_idx, token_idx, attention_val, hidden_state, num_tokens, role: str):
-    base_prob = 1 / (1 + np.exp(-10 * (attention_val - 0.5)))
-    temporal_factor = 0.5 + 0.5 * np.sin(2 * np.pi * ROLE_FREQS[role] * token_idx / max(1, num_tokens))
-    energy = np.linalg.norm(hidden_state)
-    energy_factor = np.tanh(energy / 10)
-    local_variance = np.var(hidden_state)
-    variance_factor = 1 - np.exp(-local_variance)
-    state_entropy = entropy(np.abs(hidden_state))
-    max_entropy = np.log2(max(2, hidden_state.shape[0]))
-    entropy_factor = state_entropy / max_entropy
-    factors = np.array([base_prob, temporal_factor, energy_factor, variance_factor, entropy_factor])
-    weights = ROLE_WEIGHTS[role]
-    combined_prob = float(np.dot(weights, factors))
-    noise_seed = layer_idx * 1000 + token_idx
-    noise = 0.1 * (np.sin(noise_seed * 0.1) + np.cos(noise_seed * 0.23)) / 2
-    final_prob = (combined_prob + noise) ** 1.5
-    return float(np.clip(final_prob, 0, 1))
-
-def should_play_note_stochastic(layer_idx, token_idx, attention_val, hidden_state, num_tokens, role: str, history: Dict[int,int]):
-    prob = create_note_probability(layer_idx, token_idx, attention_val, hidden_state, num_tokens, role)
-    if layer_idx in history:
-        last_played = history[layer_idx]
-        silence_duration = token_idx - last_played
-        prob *= (1 + np.tanh(silence_duration / 5) * 0.5)
-    play_note = np.random.random() < prob
-    if play_note:
-        history[layer_idx] = token_idx
-    return play_note
-
-# -------------------- Model / Latents --------------------

@dataclass
class Latents:
    hidden_states: List[torch.Tensor]
    attentions: List[torch.Tensor]
    num_layers: int
    num_tokens: int
-… (removed lines 138-191 truncated in this view)
        pass

-    return Latents(hidden_states=hidden_states[:layers], attentions=attentions[:layers], num_layers=layers, num_tokens=tokens)
-
-# -------------------- MIDI Rendering --------------------
-
-def render_midi(latents: Latents, scale_notes: List[int], base_tempo: int, velocity_range: Tuple[int, int], preset_name: str, seed: int) -> Tuple[bytes, Dict]:
-    np.random.seed(seed)
-    random.seed(seed)
-
-    scale = np.array(scale_notes, dtype=int)
-    num_layers = latents.num_layers
-    num_tokens = latents.num_tokens
-    hidden_states = [hs.float().numpy() if isinstance(hs, torch.Tensor) else hs for hs in latents.hidden_states]
-    attentions = [att.float().numpy() if isinstance(att, torch.Tensor) else att for att in latents.attentions]
-
-    layer_instruments = LAYER_INSTRUMENT_PRESETS[preset_name]
-
-    mid = MidiFile()
-    tracks: List[MidiTrack] = []
-    for ch in range(num_layers):
-        track = MidiTrack()
-        mid.tracks.append(track)
-        tracks.append(track)
-        instrument = layer_instruments.get(ch, (0, 'melody'))[0]
-        track.append(Message('program_change', program=int(instrument), time=0, channel=ch))
-
-    history: Dict[int, int] = {}
-    current_time = [0] * num_layers
-    notes_count = [0] * num_layers
-
-    for token_idx in range(num_tokens):
-        if token_idx > 0 and token_idx % 4 == 0:
-            for layer_idx in range(num_layers):
-                current_time[layer_idx] += base_tempo
-
-        pan = 64 + int(32 * np.sin(token_idx * math.pi / max(1, num_tokens)))

        for layer_idx in range(num_layers):
-… (removed lines 231-237 truncated in this view)
-                continue
-
-            if role == 'melody':
-                note = norm_to_scale(layer_vec[0], scale, octave_range=1)
-                notes_to_play = [note]
-            elif role == 'bass':
-                note = norm_to_scale(layer_vec[0], scale, octave_range=0) - 12
-                notes_to_play = [note]
-            elif role == 'harmony':
-                notes_to_play = [norm_to_scale(layer_vec[i], scale, octave_range=1) for i in range(0, min(2, len(layer_vec)), 1)]
-            elif role == 'pad':
-                notes_to_play = [norm_to_scale(layer_vec[i], scale, octave_range=1) for i in range(0, min(3, len(layer_vec)), 2)]
-            elif role == 'accent':
-                note = norm_to_scale(layer_vec[0], scale, octave_range=2) + 12
-                notes_to_play = [note]
-            else:
-                notes_to_play = [norm_to_scale(layer_vec[i], scale, octave_range=1) for i in range(0, min(2, len(layer_vec)), 3)]
-
-            base_velocity = int(attention_strength * (velocity_range[1] - velocity_range[0]) + velocity_range[0])
-            if role == 'melody':
-                velocity = min(base_velocity + 10, 127)
-            elif role == 'bass':
-                velocity = base_velocity
-            elif role == 'accent':
-                velocity = min(base_velocity + 20, 127)
            else:
-…

-            if role in ['pad', 'atmosphere']:
-                duration = base_tempo * 4
-            elif role == 'bass':
-                duration = base_tempo
-            else:
-                try:
-                    dur_factor = entropy(attn_matrix.mean(axis=0)) / (np.log2(attn_matrix.shape[-1]) + 1e-9)
-                except Exception:
-                    dur_factor = 0.5
-                duration = quantize_time(int(base_tempo * (0.5 + dur_factor * 1.5)))
-
-            for note in notes_to_play:
-                note = max(21, min(108, int(note)))
-                tracks[layer_idx].append(Message('note_on', note=note, velocity=velocity, time=current_time[layer_idx], channel=layer_idx))
-                tracks[layer_idx].append(Message('note_off', note=note, velocity=0, time=duration, channel=layer_idx))
-                current_time[layer_idx] = 0
-                notes_count[layer_idx] += 1
-
-            if token_idx == 0:
-                tracks[layer_idx].append(Message('control_change', control=10, value=pan, time=0, channel=layer_idx))
-
-    # Save to bytes
-    bio = io.BytesIO()
-    mid.save(file=bio)
-    bio.seek(0)
-
-    meta = {
-        "num_layers": num_layers,
-        "num_tokens": num_tokens,
-        "notes_per_layer": notes_count,
-        "total_notes": int(sum(notes_count)),
-        "tempo_ticks_per_beat": int(base_tempo),
-        "scale": list(map(int, scale.tolist())),
-    }
-    return bio.read(), meta
-
-# -------------------- Gradio UI --------------------
-
-DESCRIPTION = """
-# LLM Forest Orchestra — Sonify Transformer Internals
-Turn hidden states and attentions into a multi-track MIDI composition.
-
-- **Two compute modes**: *Full model* (loads a HF model and extracts latents) or *Mock latents* (quick demo with synthetic tensors — great for CPU-only Spaces).
-- Choose **scale**, **tempo**, **velocity range**, and **instrument/role preset**.
-- Exports a **MIDI** you can arrange further in your DAW.
-
-
-## Inspiration
-
-This project is inspired by the way **mushrooms and mycelial networks in forests**
-connect plants and trees, forming a living web of communication and resource sharing.
-These connections, can be turned into ethereal music.
-Just as signals move through these hidden connections, transformer models also
-pass hidden states and attentions across their layers. Here, those hidden
-connections are translated into **music**, analogous to the forest's secret orchestra.
-"""
-
-EXAMPLE_TEXT = """Joy cascades in golden waterfalls, crashing into pools of melancholy blue.
-Anger burns red through veins of marble, while serenity floats on clouds of softest grey.
-Love pulses in waves of crimson and rose, intertwining with longing's purple haze.
-Each feeling resonates at its own frequency, painting music across the soul's canvas.
-"""
-
-def parse_scale(selection: str, custom: str) -> List[int]:
-    if selection == "Custom (comma-separated MIDI notes)":
-        try:
-            return [int(x.strip()) for x in custom.split(",") if x.strip()]
-        except Exception:
-            return SCALES["C pentatonic"]
-    return SCALES[selection] if SCALES[selection] else SCALES["C pentatonic"]
-
-def generate(text, model_name, compute_mode, base_tempo, velocity_low, velocity_high, scale_choice, custom_scale, num_layers, preset, seed):
-    scale = parse_scale(scale_choice, custom_scale)
-    cfg = GenConfig(
-        model_name=model_name or DEFAULT_MODEL,
-        compute_mode=compute_mode,
-        base_tempo=int(base_tempo),
-        velocity_range=(int(velocity_low), int(velocity_high)),
-        scale=scale,
-        num_layers_limit=int(num_layers),
-        seed=int(seed),
-    )
-
-    # Get latents
-    latents = get_latents(text, cfg.model_name, cfg.compute_mode, cfg.num_layers_limit)
-
-    # Render MIDI
-    midi_bytes, meta = render_midi(latents, cfg.scale, cfg.base_tempo, cfg.velocity_range, preset, cfg.seed)
-
-    # Persist to a file for download
-    out_name = f"llm_forest_orchestra_{uuid.uuid4().hex[:8]}.mid"
-    with open(out_name, "wb") as f:
-        f.write(midi_bytes)
-
-    # Prepare quick stats
-    stats = (
-        f"Layers: {meta['num_layers']} | Tokens: {meta['num_tokens']} | "
-        f"Total notes: {meta['total_notes']} | Scale: {meta['scale']} | "
-        f"Tempo (ticks/beat): {meta['tempo_ticks_per_beat']}"
-    )
-
-    return out_name, stats, json.dumps(meta, indent=2)
-
-with gr.Blocks(title="LLM Forest Orchestra — MIDI from Transformer Internals") as demo:
-    gr.Markdown(DESCRIPTION)
-
-    with gr.Row():
-        with gr.Column():
-            text = gr.Textbox(value=EXAMPLE_TEXT, label="Input text", lines=8)
-            model_name = gr.Textbox(value=DEFAULT_MODEL, label="HF model (base) to probe", info="Should support output_hidden_states & output_attentions")
-            compute_mode = gr.Radio(choices=["Mock latents", "Full model"], value="Full model", label="Compute mode")
-            preset = gr.Dropdown(choices=list(LAYER_INSTRUMENT_PRESETS.keys()), value="Ensemble (melody+bass+pad etc.)", label="Instrument/Role preset")
-            with gr.Row():
-                base_tempo = gr.Slider(120, 960, value=480, step=1, label="Ticks per beat (tempo grid)")
-                num_layers = gr.Slider(1, 6, value=6, step=1, label="Max layers to use")
-            with gr.Row():
-                velocity_low = gr.Slider(1, 126, value=40, step=1, label="Velocity min")
-                velocity_high = gr.Slider(2, 127, value=90, step=1, label="Velocity max")
-            with gr.Row():
-                scale_choice = gr.Dropdown(choices=list(SCALES.keys()), value="C pentatonic", label="Scale")
-                custom_scale = gr.Textbox(value="", label="Custom scale notes (e.g. 60,62,65,67)")
-            seed = gr.Number(value=42, precision=0, label="Random seed")
-
-            btn = gr.Button("Generate MIDI", variant="primary")
-
-        with gr.Column():
-            midi_file = gr.File(label="MIDI output (.mid)")
-            stats = gr.Markdown("")
-            meta_json = gr.Code(label="Meta (JSON)")
-
-    btn.click(
-        fn=generate,
-        inputs=[text, model_name, compute_mode, base_tempo, velocity_low, velocity_high, scale_choice, custom_scale, num_layers, preset, seed],
-        outputs=[midi_file, stats, meta_json]
-    )

if __name__ == "__main__":
-…
@@ -4,11 +4,13 @@ import gc
import math
import time
import uuid
+import json
import spaces
import random
+from abc import ABC, abstractmethod
+from dataclasses import dataclass, field, asdict
+from typing import Dict, List, Tuple, Optional, Any, Union
+from enum import Enum

import gradio as gr
import numpy as np
@@ -18,386 +20,1142 @@ from transformers import AutoModel, AutoTokenizer
import mido
from mido import Message, MidiFile, MidiTrack

+
+# Configuration Classes
+
+class ComputeMode(Enum):
+    """Enum for computation modes."""
+    FULL_MODEL = "Full model"
+    MOCK_LATENTS = "Mock latents"
+
+
+class MusicRole(Enum):
+    """Enum for musical roles/layers."""
+    MELODY = "melody"
+    BASS = "bass"
+    HARMONY = "harmony"
+    PAD = "pad"
+    ACCENT = "accent"
+    ATMOSPHERE = "atmosphere"
+
+
+@dataclass
+class ScaleDefinition:
+    """Represents a musical scale."""
+    name: str
+    notes: List[int]
+    description: str = ""
+
+    def __post_init__(self):
+        """Validate scale notes are within MIDI range."""
+        for note in self.notes:
+            if not 0 <= note <= 127:
+                raise ValueError(f"MIDI note {note} out of range (0-127)")
+
+
+@dataclass
+class InstrumentMapping:
+    """Maps a layer to an instrument and musical role."""
+    program: int  # MIDI program number
+    role: MusicRole
+    channel: int
+    name: str = ""
+
+    def __post_init__(self):
+        """Validate MIDI program and channel."""
+        if not 0 <= self.program <= 127:
+            raise ValueError(f"MIDI program {self.program} out of range")
+        if not 0 <= self.channel <= 15:
+            raise ValueError(f"MIDI channel {self.channel} out of range")
+

@dataclass
+class GenerationConfig:
+    """Complete configuration for music generation."""
    model_name: str
+    compute_mode: ComputeMode
    base_tempo: int
    velocity_range: Tuple[int, int]
+    scale: ScaleDefinition
    num_layers_limit: int
    seed: int
+    instrument_preset: str
+
+    # Additional configuration options
+    quantization_grid: int = 120
+    octave_range: int = 2
+    dynamics_curve: str = "linear"  # linear, exponential, logarithmic
+
+    def validate(self):
+        """Validate configuration parameters."""
+        if not 1 <= self.base_tempo <= 2000:
+            raise ValueError("Tempo must be between 1 and 2000")
+        if not 1 <= self.velocity_range[0] < self.velocity_range[1] <= 127:
+            raise ValueError("Invalid velocity range")
+        if not 1 <= self.num_layers_limit <= 32:
+            raise ValueError("Number of layers must be between 1 and 32")
+
+    def to_dict(self) -> Dict:
+        """Convert config to dictionary for serialization."""
+        return {
+            "model_name": self.model_name,
+            "compute_mode": self.compute_mode.value,
+            "base_tempo": self.base_tempo,
+            "velocity_range": self.velocity_range,
+            "scale_name": self.scale.name,
+            "scale_notes": self.scale.notes,
+            "num_layers_limit": self.num_layers_limit,
+            "seed": self.seed,
+            "instrument_preset": self.instrument_preset,
+            "quantization_grid": self.quantization_grid,
+            "octave_range": self.octave_range,
+            "dynamics_curve": self.dynamics_curve
+        }
+
+    @classmethod
+    def from_dict(cls, data: Dict, scale_manager: "ScaleManager") -> "GenerationConfig":
+        """Create config from dictionary."""
+        scale = scale_manager.get_scale(data["scale_name"])
+        if scale is None:
+            scale = ScaleDefinition(name="Custom", notes=data["scale_notes"])
+
+        return cls(
+            model_name=data["model_name"],
+            compute_mode=ComputeMode(data["compute_mode"]),
+            base_tempo=data["base_tempo"],
+            velocity_range=tuple(data["velocity_range"]),
+            scale=scale,
+            num_layers_limit=data["num_layers_limit"],
+            seed=data["seed"],
+            instrument_preset=data["instrument_preset"],
+            quantization_grid=data.get("quantization_grid", 120),
+            octave_range=data.get("octave_range", 2),
+            dynamics_curve=data.get("dynamics_curve", "linear")
+        )

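As an editor's aside, not part of the diff: a minimal sketch of how the new config types above might be built, validated, and round-tripped through to_dict/from_dict. It assumes the ScaleManager defined later in this same change; all literal values are illustrative.

    # Illustrative sketch only; assumes the ScaleManager introduced further down.
    manager = ScaleManager()
    config = GenerationConfig(
        model_name="unsloth/Qwen3-14B-Base",
        compute_mode=ComputeMode.MOCK_LATENTS,
        base_tempo=480,
        velocity_range=(40, 90),
        scale=manager.get_scale("C pentatonic"),
        num_layers_limit=6,
        seed=42,
        instrument_preset="Ensemble (melody+bass+pad etc.)",
    )
    config.validate()                                   # raises ValueError on bad ranges
    restored = GenerationConfig.from_dict(config.to_dict(), manager)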

@dataclass
class Latents:
+    """Container for model latents."""
    hidden_states: List[torch.Tensor]
    attentions: List[torch.Tensor]
    num_layers: int
    num_tokens: int
+    metadata: Dict[str, Any] = field(default_factory=dict)
+
+
+# Music Components
+
+class ScaleManager:
+    """Manages musical scales and modes."""
+
+    def __init__(self):
+        """Initialize with default scales."""
+        self.scales = {
+            "C pentatonic": ScaleDefinition(
+                "C pentatonic",
+                [60, 62, 65, 67, 70, 72, 74, 77],
+                "Major pentatonic scale"
+            ),
+            "C major": ScaleDefinition(
+                "C major",
+                [60, 62, 64, 65, 67, 69, 71, 72],
+                "Major scale (Ionian mode)"
+            ),
+            "A minor": ScaleDefinition(
+                "A minor",
+                [57, 59, 60, 62, 64, 65, 67, 69],
+                "Natural minor scale (Aeolian mode)"
+            ),
+            "D dorian": ScaleDefinition(
+                "D dorian",
+                [62, 64, 65, 67, 69, 71, 72, 74],
+                "Dorian mode - minor with raised 6th"
+            ),
+            "E phrygian": ScaleDefinition(
+                "E phrygian",
+                [64, 65, 67, 69, 71, 72, 74, 76],
+                "Phrygian mode - minor with lowered 2nd"
+            ),
+            "G mixolydian": ScaleDefinition(
+                "G mixolydian",
+                [67, 69, 71, 72, 74, 76, 77, 79],
+                "Mixolydian mode - major with lowered 7th"
+            ),
+            "Blues scale": ScaleDefinition(
+                "Blues scale",
+                [60, 63, 65, 66, 67, 70, 72, 75],
+                "Blues scale with blue notes"
+            ),
+            "Chromatic": ScaleDefinition(
+                "Chromatic",
+                list(range(60, 72)),
+                "All 12 semitones"
+            )
+        }
+
+    def get_scale(self, name: str) -> Optional[ScaleDefinition]:
+        """Get scale by name."""
+        return self.scales.get(name)
+
+    def add_custom_scale(self, name: str, notes: List[int], description: str = "") -> ScaleDefinition:
+        """Add a custom scale."""
+        scale = ScaleDefinition(name, notes, description)
+        self.scales[name] = scale
+        return scale
+
+    def list_scales(self) -> List[str]:
+        """Get list of available scale names."""
+        return list(self.scales.keys())
+
+
|
213 |
+
"""Manages instrument presets for different musical styles."""
|
214 |
+
|
215 |
+
def __init__(self):
|
216 |
+
"""Initialize with default presets."""
|
217 |
+
self.presets = {
|
218 |
+
"Ensemble (melody+bass+pad etc.)": [
|
219 |
+
InstrumentMapping(0, MusicRole.MELODY, 0, "Piano"),
|
220 |
+
InstrumentMapping(33, MusicRole.BASS, 1, "Electric Bass"),
|
221 |
+
InstrumentMapping(46, MusicRole.HARMONY, 2, "Harp"),
|
222 |
+
InstrumentMapping(48, MusicRole.PAD, 3, "String Ensemble"),
|
223 |
+
InstrumentMapping(11, MusicRole.ACCENT, 4, "Vibraphone"),
|
224 |
+
InstrumentMapping(89, MusicRole.ATMOSPHERE, 5, "Pad Warm")
|
225 |
+
],
|
226 |
+
"Piano Trio (melody+bass+harmony)": [
|
227 |
+
InstrumentMapping(0, MusicRole.MELODY, 0, "Piano"),
|
228 |
+
InstrumentMapping(33, MusicRole.BASS, 1, "Electric Bass"),
|
229 |
+
InstrumentMapping(0, MusicRole.HARMONY, 2, "Piano"),
|
230 |
+
InstrumentMapping(48, MusicRole.PAD, 3, "String Ensemble"),
|
231 |
+
InstrumentMapping(0, MusicRole.ACCENT, 4, "Piano"),
|
232 |
+
InstrumentMapping(0, MusicRole.ATMOSPHERE, 5, "Piano")
|
233 |
+
],
|
234 |
+
"Pads & Atmosphere": [
|
235 |
+
InstrumentMapping(48, MusicRole.PAD, 0, "String Ensemble"),
|
236 |
+
InstrumentMapping(48, MusicRole.PAD, 1, "String Ensemble"),
|
237 |
+
InstrumentMapping(89, MusicRole.ATMOSPHERE, 2, "Pad Warm"),
|
238 |
+
InstrumentMapping(89, MusicRole.ATMOSPHERE, 3, "Pad Warm"),
|
239 |
+
InstrumentMapping(46, MusicRole.HARMONY, 4, "Harp"),
|
240 |
+
InstrumentMapping(11, MusicRole.ACCENT, 5, "Vibraphone")
|
241 |
+
],
|
242 |
+
"Orchestral": [
|
243 |
+
InstrumentMapping(40, MusicRole.MELODY, 0, "Violin"),
|
244 |
+
InstrumentMapping(42, MusicRole.BASS, 1, "Cello"),
|
245 |
+
InstrumentMapping(46, MusicRole.HARMONY, 2, "Harp"),
|
246 |
+
InstrumentMapping(48, MusicRole.PAD, 3, "String Ensemble"),
|
247 |
+
InstrumentMapping(73, MusicRole.ACCENT, 4, "Flute"),
|
248 |
+
InstrumentMapping(49, MusicRole.ATMOSPHERE, 5, "Slow Strings")
|
249 |
+
],
|
250 |
+
"Electronic": [
|
251 |
+
InstrumentMapping(80, MusicRole.MELODY, 0, "Lead Square"),
|
252 |
+
InstrumentMapping(38, MusicRole.BASS, 1, "Synth Bass"),
|
253 |
+
InstrumentMapping(81, MusicRole.HARMONY, 2, "Lead Sawtooth"),
|
254 |
+
InstrumentMapping(90, MusicRole.PAD, 3, "Pad Polysynth"),
|
255 |
+
InstrumentMapping(82, MusicRole.ACCENT, 4, "Lead Calliope"),
|
256 |
+
InstrumentMapping(91, MusicRole.ATMOSPHERE, 5, "Pad Bowed")
|
257 |
+
]
|
258 |
+
}
|
259 |
+
|
260 |
+
def get_preset(self, name: str) -> List[InstrumentMapping]:
|
261 |
+
"""Get instrument preset by name."""
|
262 |
+
return self.presets.get(name, self.presets["Ensemble (melody+bass+pad etc.)"])
|
263 |
+
|
264 |
+
def list_presets(self) -> List[str]:
|
265 |
+
"""Get list of available preset names."""
|
266 |
+
return list(self.presets.keys())
|
267 |
+
|
268 |
+
|
+# Music Generation Components
+
+class MusicMathUtils:
+    """Utility class for music-related mathematical operations."""
+
+    @staticmethod
+    def entropy(p: np.ndarray) -> float:
+        """Calculate Shannon entropy of a probability distribution."""
+        p = p / (p.sum() + 1e-9)
+        return float(-np.sum(p * np.log2(p + 1e-9)))
+
+    @staticmethod
+    def quantize_time(time_val: int, grid: int = 120) -> int:
+        """Quantize time value to grid."""
+        return int(round(time_val / grid) * grid)
+
+    @staticmethod
+    def norm_to_scale(val: float, scale: np.ndarray, octave_range: int = 2) -> int:
+        """Map normalized value to scale note with octave range."""
+        octave = int(abs(val) * octave_range) * 12
+        note_idx = int(abs(val * 100) % len(scale))
+        return int(scale[note_idx] + octave)
+
+    @staticmethod
+    def apply_dynamics_curve(value: float, curve_type: str = "linear") -> float:
+        """Apply dynamics curve to a value."""
+        value = np.clip(value, 0, 1)
+        if curve_type == "exponential":
+            return value ** 2
+        elif curve_type == "logarithmic":
+            return np.log1p(value * np.e) / np.log1p(np.e)
+        else:  # linear
+            return value
+
+
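To make the mapping concrete, a rough worked example of the helpers above (editor's sketch; input values chosen arbitrarily):

    # Sketch: how a normalized value lands on a scale note and how times snap to the grid.
    import numpy as np
    utils = MusicMathUtils()
    scale = np.array([60, 62, 65, 67, 70, 72, 74, 77])  # the "C pentatonic" preset above
    # val = 0.37 -> octave = int(0.37 * 2) * 12 = 0; note_idx = int(37) % 8 = 5 -> scale[5] = 72
    print(utils.norm_to_scale(0.37, scale))   # 72
    print(utils.quantize_time(433))           # 480, snapped to the default 120-tick grid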
+class NoteGenerator:
+    """Generates notes based on neural network latents."""
+
+    # Role-specific frequency multipliers
+    ROLE_FREQUENCIES = {
+        MusicRole.MELODY: 2.0,
+        MusicRole.BASS: 0.5,
+        MusicRole.HARMONY: 1.5,
+        MusicRole.PAD: 0.25,
+        MusicRole.ACCENT: 3.0,
+        MusicRole.ATMOSPHERE: 0.33
+    }
+
+    # Role-specific weight distributions
+    ROLE_WEIGHTS = {
+        MusicRole.MELODY: np.array([0.4, 0.2, 0.2, 0.1, 0.1]),
+        MusicRole.BASS: np.array([0.1, 0.4, 0.1, 0.3, 0.1]),
+        MusicRole.HARMONY: np.array([0.2, 0.2, 0.3, 0.2, 0.1]),
+        MusicRole.PAD: np.array([0.1, 0.3, 0.1, 0.1, 0.4]),
+        MusicRole.ACCENT: np.array([0.5, 0.1, 0.2, 0.1, 0.1]),
+        MusicRole.ATMOSPHERE: np.array([0.1, 0.2, 0.1, 0.2, 0.4])
+    }
+
+    def __init__(self, config: GenerationConfig):
+        """Initialize with generation configuration."""
+        self.config = config
+        self.math_utils = MusicMathUtils()
+        self.history: Dict[int, int] = {}
+
+    def create_note_probability(
+        self,
+        layer_idx: int,
+        token_idx: int,
+        attention_val: float,
+        hidden_state: np.ndarray,
+        num_tokens: int,
+        role: MusicRole
+    ) -> float:
+        """Calculate probability of playing a note based on multiple factors."""
+        # Base probability from attention
+        base_prob = 1 / (1 + np.exp(-10 * (attention_val - 0.5)))
+
+        # Temporal factor based on role frequency
+        temporal_factor = 0.5 + 0.5 * np.sin(
+            2 * np.pi * self.ROLE_FREQUENCIES[role] * token_idx / max(1, num_tokens)
+        )
+
+        # Energy factor from hidden state norm
+        energy = np.linalg.norm(hidden_state)
+        energy_factor = np.tanh(energy / 10)
+
+        # Variance factor
+        local_variance = np.var(hidden_state)
+        variance_factor = 1 - np.exp(-local_variance)
+
+        # Entropy factor
+        state_entropy = self.math_utils.entropy(np.abs(hidden_state))
+        max_entropy = np.log2(max(2, hidden_state.shape[0]))
+        entropy_factor = state_entropy / max_entropy
+
+        # Combine factors with role-specific weights
+        factors = np.array([base_prob, temporal_factor, energy_factor, variance_factor, entropy_factor])
+        weights = self.ROLE_WEIGHTS[role]
+        combined_prob = float(np.dot(weights, factors))
+
+        # Add deterministic noise for variation
+        noise_seed = layer_idx * 1000 + token_idx
+        noise = 0.1 * (np.sin(noise_seed * 0.1) + np.cos(noise_seed * 0.23)) / 2
+
+        # Apply dynamics curve
+        final_prob = (combined_prob + noise) ** 1.5
+        final_prob = self.math_utils.apply_dynamics_curve(final_prob, self.config.dynamics_curve)
+
+        return float(np.clip(final_prob, 0, 1))
+
+    def should_play_note(
+        self,
+        layer_idx: int,
+        token_idx: int,
+        attention_val: float,
+        hidden_state: np.ndarray,
+        num_tokens: int,
+        role: MusicRole
+    ) -> bool:
+        """Determine if a note should be played."""
+        prob = self.create_note_probability(
+            layer_idx, token_idx, attention_val, hidden_state, num_tokens, role
+        )
+
+        # Adjust probability based on silence duration
+        if layer_idx in self.history:
+            last_played = self.history[layer_idx]
+            silence_duration = token_idx - last_played
+            prob *= (1 + np.tanh(silence_duration / 5) * 0.5)
+
+        # Stochastic decision
+        play_note = np.random.random() < prob
+
+        if play_note:
+            self.history[layer_idx] = token_idx
+
+        return play_note
+
+    def generate_notes_for_role(
+        self,
+        role: MusicRole,
+        hidden_state: np.ndarray,
+        scale: np.ndarray
+    ) -> List[int]:
+        """Generate notes based on role and hidden state."""
+        if role == MusicRole.MELODY:
+            note = self.math_utils.norm_to_scale(
+                hidden_state[0], scale, octave_range=1
+            )
+            return [note]
+
+        elif role == MusicRole.BASS:
+            note = self.math_utils.norm_to_scale(
+                hidden_state[0], scale, octave_range=0
+            ) - 12
+            return [note]
+
+        elif role == MusicRole.HARMONY:
+            return [
+                self.math_utils.norm_to_scale(hidden_state[i], scale, octave_range=1)
+                for i in range(0, min(2, len(hidden_state)), 1)
+            ]
+
+        elif role == MusicRole.PAD:
+            return [
+                self.math_utils.norm_to_scale(hidden_state[i], scale, octave_range=1)
+                for i in range(0, min(3, len(hidden_state)), 2)
+            ]
+
+        elif role == MusicRole.ACCENT:
+            note = self.math_utils.norm_to_scale(
+                hidden_state[0], scale, octave_range=2
+            ) + 12
+            return [note]
+
+        else:  # ATMOSPHERE
+            return [
+                self.math_utils.norm_to_scale(hidden_state[i], scale, octave_range=1)
+                for i in range(0, min(2, len(hidden_state)), 3)
+            ]
+
+    def calculate_velocity(
+        self,
+        role: MusicRole,
+        attention_strength: float
+    ) -> int:
+        """Calculate note velocity based on role and attention."""
+        base_velocity = int(
+            attention_strength * (self.config.velocity_range[1] - self.config.velocity_range[0])
+            + self.config.velocity_range[0]
+        )
+
+        # Role-specific adjustments
+        if role == MusicRole.MELODY:
+            velocity = min(base_velocity + 10, 127)
+        elif role == MusicRole.ACCENT:
+            velocity = min(base_velocity + 20, 127)
+        elif role in [MusicRole.PAD, MusicRole.ATMOSPHERE]:
+            velocity = max(base_velocity - 10, 20)
+        else:
+            velocity = base_velocity
+
+        return velocity
+
+    def calculate_duration(
+        self,
+        role: MusicRole,
+        attention_matrix: np.ndarray
+    ) -> int:
+        """Calculate note duration based on role and attention."""
+        if role in [MusicRole.PAD, MusicRole.ATMOSPHERE]:
+            duration = self.config.base_tempo * 4
+        elif role == MusicRole.BASS:
+            duration = self.config.base_tempo
+        else:
+            try:
+                dur_factor = self.math_utils.entropy(attention_matrix.mean(axis=0)) / (
+                    np.log2(attention_matrix.shape[-1]) + 1e-9
+                )
+            except Exception:
+                dur_factor = 0.5
+            duration = self.math_utils.quantize_time(
+                int(self.config.base_tempo * (0.5 + dur_factor * 1.5)),
+                self.config.quantization_grid
+            )
+
+        return duration
+
+
+# Model Interaction
+
+class LatentExtractor(ABC):
+    """Abstract base class for latent extraction strategies."""
+
+    @abstractmethod
+    def extract(self, text: str, config: GenerationConfig, progress=None) -> Latents:
+        """Extract latents from text."""
        pass


+class MockLatentExtractor(LatentExtractor):
+    """Generate mock latents for testing without loading models."""
+
+    def extract(self, text: str, config: GenerationConfig, progress=None) -> Latents:
+        """Generate synthetic latents based on text."""
+        # Simulate token count based on text length
+        tokens = max(16, min(128, len(text.split()) * 4))
+        layers = min(config.num_layers_limit, 6)
+
+        # Generate deterministic but varied latents based on text
+        np.random.seed(hash(text) % 2**32)
+
+        hidden_states = [
+            torch.randn(1, tokens, 128) for _ in range(layers)
+        ]
+        attentions = [
+            torch.rand(1, 8, tokens, tokens) for _ in range(layers)
+        ]
+
+        metadata = {
+            "mode": "mock",
+            "text_length": len(text),
+            "generated_tokens": tokens,
+            "generated_layers": layers
+        }
+
+        return Latents(
+            hidden_states=hidden_states,
+            attentions=attentions,
+            num_layers=layers,
+            num_tokens=tokens,
+            metadata=metadata
+        )
+
+
+class ModelLatentExtractor(LatentExtractor):
+    """Extract real latents from transformer models."""
+
+    @spaces.GPU(duration=45)
+    def extract(self, text: str, config: GenerationConfig, progress=None) -> Latents:
+        """Extract latents from a real transformer model."""
+        model_name = config.model_name
+
+        # Load tokenizer
+        tokenizer = AutoTokenizer.from_pretrained(model_name)
+        if tokenizer.pad_token is None and tokenizer.eos_token is not None:
+            tokenizer.pad_token = tokenizer.eos_token
+
+        # Configure model loading
+        load_kwargs = {
+            "output_hidden_states": True,
+            "output_attentions": True,
+            "device_map": "cuda" if torch.cuda.is_available() else "cpu",
+        }
+
+        # Set appropriate dtype
+        try:
+            load_kwargs["torch_dtype"] = (
+                torch.bfloat16 if torch.cuda.is_available() else torch.float32
+            )
+        except Exception:
+            pass
+
+        # Load model
+        model = AutoModel.from_pretrained(model_name, **load_kwargs)
+
+        # Tokenize input
+        inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
+        device = next(model.parameters()).device
+        inputs = {k: v.to(device) for k, v in inputs.items()}
+
+        # Get model outputs
+        with torch.no_grad():
+            outputs = model(**inputs)
+            hidden_states = list(outputs.hidden_states)
+            attentions = list(outputs.attentions)
+
+        # Move to CPU to free VRAM
+        hidden_states = [hs.to("cpu") for hs in hidden_states]
+        attentions = [att.to("cpu") for att in attentions]
+
+        # Limit layers
+        layers = min(config.num_layers_limit, len(hidden_states))
+        tokens = hidden_states[0].shape[1]
+
+        # Clean up
+        try:
+            del model
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+            gc.collect()
+        except Exception:
+            pass
+
+        metadata = {
+            "mode": "full_model",
+            "model_name": model_name,
+            "actual_layers": len(hidden_states),
+            "used_layers": layers,
+            "tokens": tokens
+        }
+
+        return Latents(
+            hidden_states=hidden_states[:layers],
+            attentions=attentions[:layers],
+            num_layers=layers,
+            num_tokens=tokens,
+            metadata=metadata
+        )
+
+
+class LatentExtractorFactory:
+    """Factory for creating appropriate latent extractors."""
+
+    @staticmethod
+    def create(compute_mode: ComputeMode) -> LatentExtractor:
+        """Create a latent extractor based on compute mode."""
+        if compute_mode == ComputeMode.MOCK_LATENTS:
+            return MockLatentExtractor()
+        else:
+            return ModelLatentExtractor()
+
+
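A short sketch of driving the extractor factory outside the UI, continuing the hedged config sketch from earlier; "Mock latents" avoids loading any model, and the prompt text is arbitrary.

    # Sketch: pick an extractor from the compute mode and pull latents for a prompt.
    extractor = LatentExtractorFactory.create(ComputeMode.MOCK_LATENTS)
    latents = extractor.extract("The forest hums beneath the moss.", config)
    print(latents.num_layers, latents.num_tokens, latents.metadata["mode"])  # e.g. 6 24 mock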
+# MIDI Generation
+
+class MIDIRenderer:
+    """Renders MIDI files from latents."""
+
+    def __init__(self, config: GenerationConfig, instrument_manager: InstrumentPresetManager):
+        """Initialize MIDI renderer."""
+        self.config = config
+        self.instrument_manager = instrument_manager
+        self.note_generator = NoteGenerator(config)
+        self.math_utils = MusicMathUtils()
+
+    def render(self, latents: Latents) -> Tuple[bytes, Dict[str, Any]]:
+        """Render MIDI from latents."""
+        # Set random seeds for reproducibility
+        np.random.seed(self.config.seed)
+        random.seed(self.config.seed)
+        torch.manual_seed(self.config.seed)
+
+        # Prepare data
+        scale = np.array(self.config.scale.notes, dtype=int)
+        num_layers = latents.num_layers
+        num_tokens = latents.num_tokens
+
+        # Convert tensors to numpy
+        hidden_states = [
+            hs.float().numpy() if isinstance(hs, torch.Tensor) else hs
+            for hs in latents.hidden_states
+        ]
+        attentions = [
+            att.float().numpy() if isinstance(att, torch.Tensor) else att
+            for att in latents.attentions
+        ]
+
+        # Get instrument mappings
+        instrument_mappings = self.instrument_manager.get_preset(self.config.instrument_preset)
+
+        # Create MIDI file and tracks
+        midi_file = MidiFile()
+        tracks = self._create_tracks(midi_file, num_layers, instrument_mappings)
+
+        # Generate notes
+        stats = self._generate_notes(
+            tracks, hidden_states, attentions,
+            scale, num_tokens, instrument_mappings
+        )
+
+        # Convert to bytes
+        bio = io.BytesIO()
+        midi_file.save(file=bio)
+        bio.seek(0)
+
+        # Prepare metadata
+        metadata = {
+            "config": self.config.to_dict(),
+            "latents_info": latents.metadata,
+            "stats": stats,
+            "timestamp": time.time()
+        }
+
+        return bio.read(), metadata
+
+    def _create_tracks(
+        self,
+        midi_file: MidiFile,
+        num_layers: int,
+        instrument_mappings: List[InstrumentMapping]
+    ) -> List[MidiTrack]:
+        """Create MIDI tracks with instrument assignments."""
+        tracks = []
+
        for layer_idx in range(num_layers):
+            track = MidiTrack()
+            midi_file.tracks.append(track)
+            tracks.append(track)
+
+            # Get instrument mapping for this layer
+            if layer_idx < len(instrument_mappings):
+                mapping = instrument_mappings[layer_idx]
            else:
+                # Default to piano if not enough mappings
+                mapping = InstrumentMapping(0, MusicRole.MELODY, layer_idx % 16)
+
+            # Set instrument
+            track.append(Message(
+                "program_change",
+                program=mapping.program,
+                time=0,
+                channel=mapping.channel
+            ))
+
+            # Add track name
+            if mapping.name:
+                track.append(mido.MetaMessage(
+                    "track_name",
+                    name=f"{mapping.name} - {mapping.role.value}",
+                    time=0
+                ))
+
+        return tracks
+
+    def _generate_notes(
+        self,
+        tracks: List[MidiTrack],
+        hidden_states: List[np.ndarray],
+        attentions: List[np.ndarray],
+        scale: np.ndarray,
+        num_tokens: int,
+        instrument_mappings: List[InstrumentMapping]
+    ) -> Dict[str, Any]:
+        """Generate notes for all tracks."""
+        current_time = [0] * len(tracks)
+        notes_count = [0] * len(tracks)
+
+        for token_idx in range(num_tokens):
+            # Update time periodically
+            if token_idx > 0 and token_idx % 4 == 0:
+                for layer_idx in range(len(tracks)):
+                    current_time[layer_idx] += self.config.base_tempo
+
+            # Calculate panning
+            pan = 64 + int(32 * np.sin(token_idx * math.pi / max(1, num_tokens)))
+
+            # Generate notes for each layer
+            for layer_idx in range(len(tracks)):
+                if layer_idx >= len(instrument_mappings):
+                    continue
+
+                mapping = instrument_mappings[layer_idx]
+
+                # Get attention and hidden state
+                attn_matrix = attentions[min(layer_idx, len(attentions) - 1)][0, :, token_idx, :]
+                attention_strength = float(np.mean(attn_matrix))
+                layer_vec = hidden_states[layer_idx][0, token_idx]
+
+                # Check if note should be played
+                if not self.note_generator.should_play_note(
+                    layer_idx, token_idx, attention_strength,
+                    layer_vec, num_tokens, mapping.role
+                ):
+                    continue
+
+                # Generate notes
+                notes_to_play = self.note_generator.generate_notes_for_role(
+                    mapping.role, layer_vec, scale
+                )
+
+                # Calculate velocity and duration
+                velocity = self.note_generator.calculate_velocity(
+                    mapping.role, attention_strength
+                )
+                duration = self.note_generator.calculate_duration(
+                    mapping.role, attn_matrix
+                )
+
+                # Add notes to track
+                for note in notes_to_play:
+                    note = max(21, min(108, int(note)))  # Clamp to piano range
+
+                    tracks[layer_idx].append(Message(
+                        "note_on",
+                        note=note,
+                        velocity=velocity,
+                        time=current_time[layer_idx],
+                        channel=mapping.channel
+                    ))
+
+                    tracks[layer_idx].append(Message(
+                        "note_off",
+                        note=note,
+                        velocity=0,
+                        time=duration,
+                        channel=mapping.channel
+                    ))
+
+                    current_time[layer_idx] = 0
+                    notes_count[layer_idx] += 1
+
+                # Set panning on first token
+                if token_idx == 0:
+                    tracks[layer_idx].append(Message(
+                        "control_change",
+                        control=10,
+                        value=pan,
+                        time=0,
+                        channel=mapping.channel
+                    ))
+
+        return {
+            "num_layers": len(tracks),
+            "num_tokens": num_tokens,
+            "notes_per_layer": notes_count,
+            "total_notes": int(sum(notes_count)),
+            "tempo_ticks_per_beat": int(self.config.base_tempo),
+            "scale": list(map(int, scale.tolist())),
+        }
+
+
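Continuing the same editor's sketch: rendering those latents to MIDI bytes with the new renderer. The output path is hypothetical.

    # Sketch: render the mock latents into a MIDI byte string plus metadata.
    renderer = MIDIRenderer(config, InstrumentPresetManager())
    midi_bytes, meta = renderer.render(latents)
    with open("sketch_output.mid", "wb") as f:   # hypothetical file name
        f.write(midi_bytes)
    print(meta["stats"]["total_notes"])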
+# Main Orchestrator
+
+class LLMForestOrchestra:
+    """Main orchestrator class that coordinates the entire pipeline."""
+
+    DEFAULT_MODEL = "unsloth/Qwen3-14B-Base"
+
+    def __init__(self):
+        """Initialize the orchestra."""
+        self.scale_manager = ScaleManager()
+        self.instrument_manager = InstrumentPresetManager()
+        self.saved_configs: Dict[str, GenerationConfig] = {}
+
+    def generate(
+        self,
+        text: str,
+        model_name: str,
+        compute_mode: str,
+        base_tempo: int,
+        velocity_range: Tuple[int, int],
+        scale_name: str,
+        custom_scale_notes: Optional[List[int]],
+        num_layers: int,
+        instrument_preset: str,
+        seed: int,
+        quantization_grid: int = 120,
+        octave_range: int = 2,
+        dynamics_curve: str = "linear"
+    ) -> Tuple[str, Dict[str, Any]]:
+        """Generate MIDI from text input."""
+        # Get or create scale
+        if scale_name == "Custom":
+            if not custom_scale_notes:
+                raise ValueError("Custom scale requires note list")
+            scale = ScaleDefinition("Custom", custom_scale_notes)
+        else:
+            scale = self.scale_manager.get_scale(scale_name)
+            if scale is None:
+                raise ValueError(f"Unknown scale: {scale_name}")
+
+        # Create configuration
+        config = GenerationConfig(
+            model_name=model_name or self.DEFAULT_MODEL,
+            compute_mode=ComputeMode(compute_mode),
+            base_tempo=base_tempo,
+            velocity_range=velocity_range,
+            scale=scale,
+            num_layers_limit=num_layers,
+            seed=seed,
+            instrument_preset=instrument_preset,
+            quantization_grid=quantization_grid,
+            octave_range=octave_range,
+            dynamics_curve=dynamics_curve
+        )
+
+        # Validate configuration
+        config.validate()
+
+        # Extract latents
+        extractor = LatentExtractorFactory.create(config.compute_mode)
+        latents = extractor.extract(text, config)
+
+        # Render MIDI
+        renderer = MIDIRenderer(config, self.instrument_manager)
+        midi_bytes, metadata = renderer.render(latents)
+
+        # Save MIDI file
+        filename = f"llm_forest_orchestra_{uuid.uuid4().hex[:8]}.mid"
+        with open(filename, "wb") as f:
+            f.write(midi_bytes)
+
+        return filename, metadata
+
+    def save_config(self, name: str, config: GenerationConfig):
+        """Save a configuration for later use."""
+        self.saved_configs[name] = config
+
+    def load_config(self, name: str) -> Optional[GenerationConfig]:
+        """Load a saved configuration."""
+        return self.saved_configs.get(name)
+
+    def export_config(self, config: GenerationConfig, filepath: str):
+        """Export configuration to JSON file."""
+        with open(filepath, "w") as f:
+            json.dump(config.to_dict(), f, indent=2)
+
+    def import_config(self, filepath: str) -> GenerationConfig:
+        """Import configuration from JSON file."""
+        with open(filepath, "r") as f:
+            data = json.load(f)
+        return GenerationConfig.from_dict(data, self.scale_manager)
+
+
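And the facade in one go, as a hedged end-to-end sketch of the refactored pipeline using mock latents; the prompt and parameter values are illustrative, and the output file name is generated by the class itself.

    # Sketch: the whole pipeline through the LLMForestOrchestra facade.
    orchestra = LLMForestOrchestra()
    midi_path, meta = orchestra.generate(
        text="Rain taps a slow rhythm on the canopy.",
        model_name="",                      # falls back to DEFAULT_MODEL
        compute_mode="Mock latents",
        base_tempo=480,
        velocity_range=(40, 90),
        scale_name="A minor",
        custom_scale_notes=None,
        num_layers=6,
        instrument_preset="Orchestral",
        seed=7,
    )
    print(midi_path, meta["stats"]["total_notes"])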
+# Gradio UI
+
+class GradioInterface:
+    """Manages the Gradio user interface."""
+
+    DESCRIPTION = """
+# 🌲 LLM Forest Orchestra — Sonify Transformer Internals
+
+Transform the hidden states and attention patterns of language models into multi-layered musical compositions.
+
+## 🍄 Inspiration
+
+This project is inspired by the way **mushrooms and mycelial networks in forests**
+connect plants and trees, forming a living web of communication and resource sharing.
+These connections can be turned into ethereal music.
+Just as signals move through these hidden connections, transformer models also
+pass hidden states and attentions across their layers. Here, those hidden
+connections are translated into **music**, analogous to the forest's secret orchestra.
+
+## Features
+- **Two compute modes**: Full model (GPU) or Mock latents (CPU-friendly)
+- **Multiple musical scales**: From pentatonic to chromatic
+- **Instrument presets**: Orchestral, electronic, ensemble, and more
+- **Advanced controls**: Dynamics curves, quantization, velocity ranges
+- **Export**: Standard MIDI files for further editing in your DAW
+"""
+
+    EXAMPLE_TEXT = """Joy cascades in golden waterfalls, crashing into pools of melancholy blue.
+Anger burns red through veins of marble, while serenity floats on clouds of softest grey.
+Love pulses in waves of crimson and rose, intertwining with longing's purple haze.
+Each feeling resonates at its own frequency, painting music across the soul's canvas."""
+
+    def __init__(self, orchestra: LLMForestOrchestra):
+        """Initialize the interface."""
+        self.orchestra = orchestra
+
+    def create_interface(self) -> gr.Blocks:
+        """Create the Gradio interface."""
+        with gr.Blocks(title="LLM Forest Orchestra", theme=gr.themes.Soft()) as demo:
+            gr.Markdown(self.DESCRIPTION)
+
+            with gr.Tabs():
+                with gr.TabItem("🎵 Generate Music"):
+                    self._create_generation_tab()
+
+        return demo
+
+    def _create_generation_tab(self):
+        """Create the main generation tab."""
+        with gr.Row():
+            with gr.Column(scale=1):
+                text_input = gr.Textbox(
+                    value=self.EXAMPLE_TEXT,
+                    label="Input Text",
+                    lines=8,
+                    placeholder="Enter text to sonify..."
+                )
+
+                model_name = gr.Textbox(
+                    value=self.orchestra.DEFAULT_MODEL,
+                    label="Hugging Face Model",
+                    info="Model must support output_hidden_states and output_attentions"
+                )
+
+                compute_mode = gr.Radio(
+                    choices=["Full model", "Mock latents"],
+                    value="Mock latents",
+                    label="Compute Mode",
+                    info="Mock latents for quick CPU-only demo"
+                )
+
+                with gr.Row():
+                    instrument_preset = gr.Dropdown(
+                        choices=self.orchestra.instrument_manager.list_presets(),
+                        value="Ensemble (melody+bass+pad etc.)",
+                        label="Instrument Preset"
+                    )
+
+                    scale_choice = gr.Dropdown(
+                        choices=self.orchestra.scale_manager.list_scales() + ["Custom"],
+                        value="C pentatonic",
+                        label="Musical Scale"
+                    )
+
+                    custom_scale = gr.Textbox(
+                        value="",
+                        label="Custom Scale Notes",
+                        placeholder="60,62,65,67,70",
+                        visible=False
+                    )
+
+                with gr.Row():
+                    base_tempo = gr.Slider(
+                        120, 960,
+                        value=480,
+                        step=1,
+                        label="Tempo (ticks per beat)"
+                    )
+
+                    num_layers = gr.Slider(
+                        1, 6,
+                        value=6,
+                        step=1,
+                        label="Max Layers"
+                    )
+
+                with gr.Row():
+                    velocity_low = gr.Slider(
+                        1, 126,
+                        value=40,
+                        step=1,
+                        label="Min Velocity"
+                    )
+
+                    velocity_high = gr.Slider(
+                        2, 127,
+                        value=90,
+                        step=1,
+                        label="Max Velocity"
+                    )
+
+                seed = gr.Number(
+                    value=42,
+                    precision=0,
+                    label="Random Seed"
+                )
+
+                generate_btn = gr.Button(
+                    "🎼 Generate MIDI",
+                    variant="primary",
+                    size="lg"
+                )
+
+            with gr.Column(scale=1):
+                midi_output = gr.File(
+                    label="Generated MIDI File",
+                    file_types=[".mid", ".midi"]
+                )
+
+                stats_display = gr.Markdown(label="Quick Stats")
+
+                metadata_json = gr.Code(
+                    label="Metadata (JSON)",
+                    language="json"
+                )
+
+                with gr.Row():
+                    play_instructions = gr.Markdown(
+                        """
+                        ### 🎧 How to Play
+                        1. Download the MIDI file
+                        2. Open in any DAW or MIDI player
+                        3. Adjust instruments and effects as desired
+                        4. Export to audio format
+                        """
+                    )
+
+        # Set up interactions
+        def update_custom_scale_visibility(choice):
+            return gr.update(visible=(choice == "Custom"))
+
+        scale_choice.change(
+            update_custom_scale_visibility,
+            inputs=[scale_choice],
+            outputs=[custom_scale]
+        )
+
+        def generate_wrapper(
+            text, model_name, compute_mode, base_tempo,
+            velocity_low, velocity_high, scale_choice,
+            custom_scale, num_layers, instrument_preset, seed
+        ):
+            """Wrapper for generation with error handling."""
+            try:
+                # Parse custom scale if needed
+                custom_notes = None
+                if scale_choice == "Custom" and custom_scale:
+                    custom_notes = [int(x.strip()) for x in custom_scale.split(",")]
+
+                # Generate
+                filename, metadata = self.orchestra.generate(
+                    text=text,
+                    model_name=model_name,
+                    compute_mode=compute_mode,
+                    base_tempo=int(base_tempo),
+                    velocity_range=(int(velocity_low), int(velocity_high)),
+                    scale_name=scale_choice,
+                    custom_scale_notes=custom_notes,
+                    num_layers=int(num_layers),
+                    instrument_preset=instrument_preset,
+                    seed=int(seed)
+                )
+
+                # Format stats
+                stats = metadata.get("stats", {})
+                stats_text = f"""
+                ### Generation Statistics
+                - **Layers Used**: {stats.get('num_layers', 'N/A')}
+                - **Tokens Processed**: {stats.get('num_tokens', 'N/A')}
+                - **Total Notes**: {stats.get('total_notes', 'N/A')}
+                - **Notes per Layer**: {stats.get('notes_per_layer', [])}
+                - **Scale**: {stats.get('scale', [])}
+                - **Tempo**: {stats.get('tempo_ticks_per_beat', 'N/A')} ticks/beat
+                """
+
+                return filename, stats_text, json.dumps(metadata, indent=2)
+
+            except Exception as e:
+                error_msg = f"### ❌ Error\n{str(e)}"
+                return None, error_msg, json.dumps({"error": str(e)}, indent=2)
+
+        generate_btn.click(
+            fn=generate_wrapper,
+            inputs=[
+                text_input, model_name, compute_mode, base_tempo,
+                velocity_low, velocity_high, scale_choice,
+                custom_scale, num_layers, instrument_preset, seed
+            ],
+            outputs=[midi_output, stats_display, metadata_json]
+        )
+
+
+# Main Entry Point
+
+def main():
+    """Main entry point for the application."""
+    # Initialize orchestra
+    orchestra = LLMForestOrchestra()
+
+    # Create interface
+    interface = GradioInterface(orchestra)
+    demo = interface.create_interface()
+
+    # Launch
+    demo.launch()


if __name__ == "__main__":
+    main()