Locutusque committed on
Commit 5048db9 · verified · 1 Parent(s): d61ceaa

Big refactor + more features

Files changed (1)
  1. app.py +1121 -363
app.py CHANGED
@@ -4,11 +4,13 @@ import gc
4
  import math
5
  import time
6
  import uuid
 
7
  import spaces
8
  import random
9
- from dataclasses import dataclass
10
- from typing import Dict, List, Tuple, Optional
11
- import json
 
12
 
13
  import gradio as gr
14
  import numpy as np
@@ -18,386 +20,1142 @@ from transformers import AutoModel, AutoTokenizer
18
  import mido
19
  from mido import Message, MidiFile, MidiTrack
20
 
21
- # -------------------- Defaults & Music Helpers --------------------
22
-
23
- DEFAULT_MODEL = "unsloth/Qwen3-14B-Base"
24
-
25
- SCALES = {
26
- "C pentatonic": [60, 62, 65, 67, 70, 72, 74, 77],
27
- "C major": [60, 62, 64, 65, 67, 69, 71, 72],
28
- "A minor": [57, 59, 60, 62, 64, 65, 67, 69],
29
- "Custom (comma-separated MIDI notes)": [],
30
- }
31
-
32
- LAYER_INSTRUMENT_PRESETS = {
33
- "Ensemble (melody+bass+pad etc.)": {
34
- 0: (0, 'melody'),
35
- 1: (33, 'bass'),
36
- 2: (46, 'harmony'),
37
- 3: (48, 'pad'),
38
- 4: (11, 'accent'),
39
- 5: (89, 'atmosphere'),
40
- },
41
- "Piano Trio (melody+bass+harmony)": {
42
- 0: (0, 'melody'),
43
- 1: (33, 'bass'),
44
- 2: (0, 'harmony'),
45
- 3: (48, 'pad'),
46
- 4: (0, 'accent'),
47
- 5: (0, 'atmosphere'),
48
- },
49
- "Pads & Atmos": {
50
- 0: (48, 'pad'),
51
- 1: (48, 'pad'),
52
- 2: (89, 'atmosphere'),
53
- 3: (89, 'atmosphere'),
54
- 4: (46, 'harmony'),
55
- 5: (11, 'accent'),
56
- },
57
- }
58
 
59
  @dataclass
60
- class GenConfig:
 
61
  model_name: str
62
- compute_mode: str # "Full model" or "Mock latents"
63
  base_tempo: int
64
  velocity_range: Tuple[int, int]
65
- scale: List[int]
66
  num_layers_limit: int
67
  seed: int
68
 
69
- # --- Core math helpers ---
70
-
71
- def entropy(p: np.ndarray) -> float:
72
- p = p / (p.sum() + 1e-9)
73
- return float(-np.sum(p * np.log2(p + 1e-9)))
74
-
75
- def quantize_time(time_val: int, grid: int = 120) -> int:
76
- return int(round(time_val / grid) * grid)
77
-
78
- def norm_to_scale(val: float, scale: np.ndarray, octave_range: int = 2) -> int:
79
- octave = int(abs(val) * octave_range) * 12
80
- note_idx = int(abs(val * 100) % len(scale))
81
- return int(scale[note_idx] + octave)
82
-
83
- ROLE_FREQS = {
84
- 'melody': 2.0,
85
- 'bass': 0.5,
86
- 'harmony': 1.5,
87
- 'pad': 0.25,
88
- 'accent': 3.0,
89
- 'atmosphere': 0.33
90
- }
91
-
92
- ROLE_WEIGHTS = {
93
- 'melody': np.array([0.4, 0.2, 0.2, 0.1, 0.1]),
94
- 'bass': np.array([0.1, 0.4, 0.1, 0.3, 0.1]),
95
- 'harmony': np.array([0.2, 0.2, 0.3, 0.2, 0.1]),
96
- 'pad': np.array([0.1, 0.3, 0.1, 0.1, 0.4]),
97
- 'accent': np.array([0.5, 0.1, 0.2, 0.1, 0.1]),
98
- 'atmosphere': np.array([0.1, 0.2, 0.1, 0.2, 0.4])
99
- }
100
-
101
- def create_note_probability(layer_idx, token_idx, attention_val, hidden_state, num_tokens, role: str):
102
- base_prob = 1 / (1 + np.exp(-10 * (attention_val - 0.5)))
103
- temporal_factor = 0.5 + 0.5 * np.sin(2 * np.pi * ROLE_FREQS[role] * token_idx / max(1, num_tokens))
104
- energy = np.linalg.norm(hidden_state)
105
- energy_factor = np.tanh(energy / 10)
106
- local_variance = np.var(hidden_state)
107
- variance_factor = 1 - np.exp(-local_variance)
108
- state_entropy = entropy(np.abs(hidden_state))
109
- max_entropy = np.log2(max(2, hidden_state.shape[0]))
110
- entropy_factor = state_entropy / max_entropy
111
- factors = np.array([base_prob, temporal_factor, energy_factor, variance_factor, entropy_factor])
112
- weights = ROLE_WEIGHTS[role]
113
- combined_prob = float(np.dot(weights, factors))
114
- noise_seed = layer_idx * 1000 + token_idx
115
- noise = 0.1 * (np.sin(noise_seed * 0.1) + np.cos(noise_seed * 0.23)) / 2
116
- final_prob = (combined_prob + noise) ** 1.5
117
- return float(np.clip(final_prob, 0, 1))
118
-
119
- def should_play_note_stochastic(layer_idx, token_idx, attention_val, hidden_state, num_tokens, role: str, history: Dict[int,int]):
120
- prob = create_note_probability(layer_idx, token_idx, attention_val, hidden_state, num_tokens, role)
121
- if layer_idx in history:
122
- last_played = history[layer_idx]
123
- silence_duration = token_idx - last_played
124
- prob *= (1 + np.tanh(silence_duration / 5) * 0.5)
125
- play_note = np.random.random() < prob
126
- if play_note:
127
- history[layer_idx] = token_idx
128
- return play_note
129
-
130
- # -------------------- Model / Latents --------------------
131
 
132
  @dataclass
133
  class Latents:
 
134
  hidden_states: List[torch.Tensor]
135
  attentions: List[torch.Tensor]
136
  num_layers: int
137
  num_tokens: int
138
-
139
- @spaces.GPU(duration=45)
140
- def get_latents(text: str, model_name: str, compute_mode: str, max_layers: int, progress=gr.Progress(track_tqdm=True)) -> Latents:
141
- if compute_mode == "Mock latents":
142
- # Fast path for Spaces without big GPUs
143
- tokens = max(16, min(128, len(text.split()) * 4))
144
- layers = min(max_layers, 6)
145
- hidden_states = [torch.randn(1, tokens, 128) for _ in range(layers)]
146
- attentions = [torch.rand(1, 8, tokens, tokens) for _ in range(layers)]
147
- return Latents(hidden_states=hidden_states, attentions=attentions, num_layers=layers, num_tokens=tokens)
148
-
149
- # Full model path
150
- tokenizer = AutoTokenizer.from_pretrained(model_name)
151
- if tokenizer.pad_token is None and tokenizer.eos_token is not None:
152
- tokenizer.pad_token = tokenizer.eos_token
153
-
154
- # Try different memory-friendly loading strategies
155
- load_kwargs = dict(
156
- output_hidden_states=True,
157
- output_attentions=True,
158
- device_map="cuda",
159
- )
160
-
161
- # dtype heuristics
162
- try:
163
- load_kwargs["torch_dtype"] = torch.bfloat16 if torch.cuda.is_available() else torch.float32
164
- except Exception:
165
- pass
166
-
167
- model = AutoModel.from_pretrained(model_name, **load_kwargs)
168
-
169
- inputs = tokenizer(text, return_tensors="pt")
170
- device = next(model.parameters()).device
171
- inputs = {k: v.to(device) for k, v in inputs.items()}
172
-
173
- with torch.no_grad():
174
- outputs = model(**inputs)
175
- hidden_states = list(outputs.hidden_states)
176
- attentions = list(outputs.attentions)
177
-
178
- # Move to CPU numpy-friendly dtype to free VRAM
179
- hidden_states = [hs.to("cpu") for hs in hidden_states]
180
- attentions = [att.to("cpu") for att in attentions]
181
-
182
- # Trim layers
183
- layers = min(max_layers, 6, len(hidden_states))
184
- tokens = hidden_states[0].shape[1]
185
-
186
- # Clean up VRAM
187
- try:
188
- del model
189
- if torch.cuda.is_available():
190
- torch.cuda.empty_cache()
191
- except Exception:
192
  pass
193
 
194
- return Latents(hidden_states=hidden_states[:layers], attentions=attentions[:layers], num_layers=layers, num_tokens=tokens)
195
-
196
- # -------------------- MIDI Rendering --------------------
197
-
198
- def render_midi(latents: Latents, scale_notes: List[int], base_tempo: int, velocity_range: Tuple[int, int], preset_name: str, seed: int) -> Tuple[bytes, Dict]:
199
- np.random.seed(seed)
200
- random.seed(seed)
201
-
202
- scale = np.array(scale_notes, dtype=int)
203
- num_layers = latents.num_layers
204
- num_tokens = latents.num_tokens
205
- hidden_states = [hs.float().numpy() if isinstance(hs, torch.Tensor) else hs for hs in latents.hidden_states]
206
- attentions = [att.float().numpy() if isinstance(att, torch.Tensor) else att for att in latents.attentions]
207
-
208
- layer_instruments = LAYER_INSTRUMENT_PRESETS[preset_name]
209
-
210
- mid = MidiFile()
211
- tracks: List[MidiTrack] = []
212
- for ch in range(num_layers):
213
- track = MidiTrack()
214
- mid.tracks.append(track)
215
- tracks.append(track)
216
- instrument = layer_instruments.get(ch, (0, 'melody'))[0]
217
- track.append(Message('program_change', program=int(instrument), time=0, channel=ch))
218
-
219
- history: Dict[int, int] = {}
220
- current_time = [0] * num_layers
221
- notes_count = [0] * num_layers
222
-
223
- for token_idx in range(num_tokens):
224
- if token_idx > 0 and token_idx % 4 == 0:
225
- for layer_idx in range(num_layers):
226
- current_time[layer_idx] += base_tempo
227
-
228
- pan = 64 + int(32 * np.sin(token_idx * math.pi / max(1, num_tokens)))
229
 
230
  for layer_idx in range(num_layers):
231
- role = layer_instruments.get(layer_idx, (0, 'melody'))[1]
232
-
233
- attn_matrix = attentions[min(layer_idx, len(attentions) - 1)][0, :, token_idx, :]
234
- attention_strength = float(np.mean(attn_matrix))
235
- layer_vec = hidden_states[layer_idx][0, token_idx]
236
-
237
- if not should_play_note_stochastic(layer_idx, token_idx, attention_strength, layer_vec, num_tokens, role, history):
238
- continue
239
-
240
- if role == 'melody':
241
- note = norm_to_scale(layer_vec[0], scale, octave_range=1)
242
- notes_to_play = [note]
243
- elif role == 'bass':
244
- note = norm_to_scale(layer_vec[0], scale, octave_range=0) - 12
245
- notes_to_play = [note]
246
- elif role == 'harmony':
247
- notes_to_play = [norm_to_scale(layer_vec[i], scale, octave_range=1) for i in range(0, min(2, len(layer_vec)), 1)]
248
- elif role == 'pad':
249
- notes_to_play = [norm_to_scale(layer_vec[i], scale, octave_range=1) for i in range(0, min(3, len(layer_vec)), 2)]
250
- elif role == 'accent':
251
- note = norm_to_scale(layer_vec[0], scale, octave_range=2) + 12
252
- notes_to_play = [note]
253
- else:
254
- notes_to_play = [norm_to_scale(layer_vec[i], scale, octave_range=1) for i in range(0, min(2, len(layer_vec)), 3)]
255
-
256
- base_velocity = int(attention_strength * (velocity_range[1] - velocity_range[0]) + velocity_range[0])
257
- if role == 'melody':
258
- velocity = min(base_velocity + 10, 127)
259
- elif role == 'bass':
260
- velocity = base_velocity
261
- elif role == 'accent':
262
- velocity = min(base_velocity + 20, 127)
263
  else:
264
- velocity = max(base_velocity - 10, 20)
265
 
266
- if role in ['pad', 'atmosphere']:
267
- duration = base_tempo * 4
268
- elif role == 'bass':
269
- duration = base_tempo
270
- else:
271
- try:
272
- dur_factor = entropy(attn_matrix.mean(axis=0)) / (np.log2(attn_matrix.shape[-1]) + 1e-9)
273
- except Exception:
274
- dur_factor = 0.5
275
- duration = quantize_time(int(base_tempo * (0.5 + dur_factor * 1.5)))
276
-
277
- for note in notes_to_play:
278
- note = max(21, min(108, int(note)))
279
- tracks[layer_idx].append(Message('note_on', note=note, velocity=velocity, time=current_time[layer_idx], channel=layer_idx))
280
- tracks[layer_idx].append(Message('note_off', note=note, velocity=0, time=duration, channel=layer_idx))
281
- current_time[layer_idx] = 0
282
- notes_count[layer_idx] += 1
283
-
284
- if token_idx == 0:
285
- tracks[layer_idx].append(Message('control_change', control=10, value=pan, time=0, channel=layer_idx))
286
-
287
- # Save to bytes
288
- bio = io.BytesIO()
289
- mid.save(file=bio)
290
- bio.seek(0)
291
-
292
- meta = {
293
- "num_layers": num_layers,
294
- "num_tokens": num_tokens,
295
- "notes_per_layer": notes_count,
296
- "total_notes": int(sum(notes_count)),
297
- "tempo_ticks_per_beat": int(base_tempo),
298
- "scale": list(map(int, scale.tolist())),
299
- }
300
- return bio.read(), meta
301
-
302
- # -------------------- Gradio UI --------------------
303
-
304
- DESCRIPTION = """
305
- # LLM Forest Orchestra — Sonify Transformer Internals
306
- Turn hidden states and attentions into a multi-track MIDI composition.
307
-
308
- - **Two compute modes**: *Full model* (loads a HF model and extracts latents) or *Mock latents* (quick demo with synthetic tensors — great for CPU-only Spaces).
309
- - Choose **scale**, **tempo**, **velocity range**, and **instrument/role preset**.
310
- - Exports a **MIDI** you can arrange further in your DAW.
311
-
312
-
313
- ## Inspiration
314
-
315
- This project is inspired by the way **mushrooms and mycelial networks in forests**
316
- connect plants and trees, forming a living web of communication and resource sharing.
317
- These connections can be turned into ethereal music.
318
- Just as signals move through these hidden connections, transformer models also
319
- pass hidden states and attentions across their layers. Here, those hidden
320
- connections are translated into **music**, analogous to the forest's secret orchestra.
321
- """
322
-
323
- EXAMPLE_TEXT = """Joy cascades in golden waterfalls, crashing into pools of melancholy blue.
324
- Anger burns red through veins of marble, while serenity floats on clouds of softest grey.
325
- Love pulses in waves of crimson and rose, intertwining with longing's purple haze.
326
- Each feeling resonates at its own frequency, painting music across the soul's canvas.
327
- """
328
-
329
- def parse_scale(selection: str, custom: str) -> List[int]:
330
- if selection == "Custom (comma-separated MIDI notes)":
331
- try:
332
- return [int(x.strip()) for x in custom.split(",") if x.strip()]
333
- except Exception:
334
- return SCALES["C pentatonic"]
335
- return SCALES[selection] if SCALES[selection] else SCALES["C pentatonic"]
336
-
337
- def generate(text, model_name, compute_mode, base_tempo, velocity_low, velocity_high, scale_choice, custom_scale, num_layers, preset, seed):
338
- scale = parse_scale(scale_choice, custom_scale)
339
- cfg = GenConfig(
340
- model_name=model_name or DEFAULT_MODEL,
341
- compute_mode=compute_mode,
342
- base_tempo=int(base_tempo),
343
- velocity_range=(int(velocity_low), int(velocity_high)),
344
- scale=scale,
345
- num_layers_limit=int(num_layers),
346
- seed=int(seed),
347
- )
348
-
349
- # Get latents
350
- latents = get_latents(text, cfg.model_name, cfg.compute_mode, cfg.num_layers_limit)
351
-
352
- # Render MIDI
353
- midi_bytes, meta = render_midi(latents, cfg.scale, cfg.base_tempo, cfg.velocity_range, preset, cfg.seed)
354
-
355
- # Persist to a file for download
356
- out_name = f"llm_forest_orchestra_{uuid.uuid4().hex[:8]}.mid"
357
- with open(out_name, "wb") as f:
358
- f.write(midi_bytes)
359
-
360
- # Prepare quick stats
361
- stats = (
362
- f"Layers: {meta['num_layers']} | Tokens: {meta['num_tokens']} | "
363
- f"Total notes: {meta['total_notes']} | Scale: {meta['scale']} | "
364
- f"Tempo (ticks/beat): {meta['tempo_ticks_per_beat']}"
365
- )
366
-
367
- return out_name, stats, json.dumps(meta, indent=2)
368
-
369
- with gr.Blocks(title="LLM Forest Orchestra — MIDI from Transformer Internals") as demo:
370
- gr.Markdown(DESCRIPTION)
371
-
372
- with gr.Row():
373
- with gr.Column():
374
- text = gr.Textbox(value=EXAMPLE_TEXT, label="Input text", lines=8)
375
- model_name = gr.Textbox(value=DEFAULT_MODEL, label="HF model (base) to probe", info="Should support output_hidden_states & output_attentions")
376
- compute_mode = gr.Radio(choices=["Mock latents", "Full model"], value="Full model", label="Compute mode")
377
- preset = gr.Dropdown(choices=list(LAYER_INSTRUMENT_PRESETS.keys()), value="Ensemble (melody+bass+pad etc.)", label="Instrument/Role preset")
378
- with gr.Row():
379
- base_tempo = gr.Slider(120, 960, value=480, step=1, label="Ticks per beat (tempo grid)")
380
- num_layers = gr.Slider(1, 6, value=6, step=1, label="Max layers to use")
381
- with gr.Row():
382
- velocity_low = gr.Slider(1, 126, value=40, step=1, label="Velocity min")
383
- velocity_high = gr.Slider(2, 127, value=90, step=1, label="Velocity max")
384
- with gr.Row():
385
- scale_choice = gr.Dropdown(choices=list(SCALES.keys()), value="C pentatonic", label="Scale")
386
- custom_scale = gr.Textbox(value="", label="Custom scale notes (e.g. 60,62,65,67)")
387
- seed = gr.Number(value=42, precision=0, label="Random seed")
388
-
389
- btn = gr.Button("Generate MIDI", variant="primary")
390
-
391
- with gr.Column():
392
- midi_file = gr.File(label="MIDI output (.mid)")
393
- stats = gr.Markdown("")
394
- meta_json = gr.Code(label="Meta (JSON)")
395
-
396
- btn.click(
397
- fn=generate,
398
- inputs=[text, model_name, compute_mode, base_tempo, velocity_low, velocity_high, scale_choice, custom_scale, num_layers, preset, seed],
399
- outputs=[midi_file, stats, meta_json]
400
- )
401
 
402
  if __name__ == "__main__":
403
- demo.launch()
 
4
  import math
5
  import time
6
  import uuid
7
+ import json
8
  import spaces
9
  import random
10
+ from abc import ABC, abstractmethod
11
+ from dataclasses import dataclass, field, asdict
12
+ from typing import Dict, List, Tuple, Optional, Any, Union
13
+ from enum import Enum
14
 
15
  import gradio as gr
16
  import numpy as np
 
20
  import mido
21
  from mido import Message, MidiFile, MidiTrack
22
 
23
+
24
+ # Configuration Classes
25
+
26
+ class ComputeMode(Enum):
27
+ """Enum for computation modes."""
28
+ FULL_MODEL = "Full model"
29
+ MOCK_LATENTS = "Mock latents"
30
+
31
+
32
+ class MusicRole(Enum):
33
+ """Enum for musical roles/layers."""
34
+ MELODY = "melody"
35
+ BASS = "bass"
36
+ HARMONY = "harmony"
37
+ PAD = "pad"
38
+ ACCENT = "accent"
39
+ ATMOSPHERE = "atmosphere"
40
+
41
+
42
+ @dataclass
43
+ class ScaleDefinition:
44
+ """Represents a musical scale."""
45
+ name: str
46
+ notes: List[int]
47
+ description: str = ""
48
+
49
+ def __post_init__(self):
50
+ """Validate scale notes are within MIDI range."""
51
+ for note in self.notes:
52
+ if not 0 <= note <= 127:
53
+ raise ValueError(f"MIDI note {note} out of range (0-127)")
54
+
55
+
56
+ @dataclass
57
+ class InstrumentMapping:
58
+ """Maps a layer to an instrument and musical role."""
59
+ program: int # MIDI program number
60
+ role: MusicRole
61
+ channel: int
62
+ name: str = ""
63
+
64
+ def __post_init__(self):
65
+ """Validate MIDI program and channel."""
66
+ if not 0 <= self.program <= 127:
67
+ raise ValueError(f"MIDI program {self.program} out of range")
68
+ if not 0 <= self.channel <= 15:
69
+ raise ValueError(f"MIDI channel {self.channel} out of range")
70
+
71
 
72
  @dataclass
73
+ class GenerationConfig:
74
+ """Complete configuration for music generation."""
75
  model_name: str
76
+ compute_mode: ComputeMode
77
  base_tempo: int
78
  velocity_range: Tuple[int, int]
79
+ scale: ScaleDefinition
80
  num_layers_limit: int
81
  seed: int
82
+ instrument_preset: str
83
+
84
+ # Additional configuration options
85
+ quantization_grid: int = 120
86
+ octave_range: int = 2
87
+ dynamics_curve: str = "linear" # linear, exponential, logarithmic
88
+
89
+ def validate(self):
90
+ """Validate configuration parameters."""
91
+ if not 1 <= self.base_tempo <= 2000:
92
+ raise ValueError("Tempo must be between 1 and 2000")
93
+ if not 1 <= self.velocity_range[0] < self.velocity_range[1] <= 127:
94
+ raise ValueError("Invalid velocity range")
95
+ if not 1 <= self.num_layers_limit <= 32:
96
+ raise ValueError("Number of layers must be between 1 and 32")
97
+
98
+ def to_dict(self) -> Dict:
99
+ """Convert config to dictionary for serialization."""
100
+ return {
101
+ "model_name": self.model_name,
102
+ "compute_mode": self.compute_mode.value,
103
+ "base_tempo": self.base_tempo,
104
+ "velocity_range": self.velocity_range,
105
+ "scale_name": self.scale.name,
106
+ "scale_notes": self.scale.notes,
107
+ "num_layers_limit": self.num_layers_limit,
108
+ "seed": self.seed,
109
+ "instrument_preset": self.instrument_preset,
110
+ "quantization_grid": self.quantization_grid,
111
+ "octave_range": self.octave_range,
112
+ "dynamics_curve": self.dynamics_curve
113
+ }
114
+
115
+ @classmethod
116
+ def from_dict(cls, data: Dict, scale_manager: "ScaleManager") -> "GenerationConfig":
117
+ """Create config from dictionary."""
118
+ scale = scale_manager.get_scale(data["scale_name"])
119
+ if scale is None:
120
+ scale = ScaleDefinition(name="Custom", notes=data["scale_notes"])
121
+
122
+ return cls(
123
+ model_name=data["model_name"],
124
+ compute_mode=ComputeMode(data["compute_mode"]),
125
+ base_tempo=data["base_tempo"],
126
+ velocity_range=tuple(data["velocity_range"]),
127
+ scale=scale,
128
+ num_layers_limit=data["num_layers_limit"],
129
+ seed=data["seed"],
130
+ instrument_preset=data["instrument_preset"],
131
+ quantization_grid=data.get("quantization_grid", 120),
132
+ octave_range=data.get("octave_range", 2),
133
+ dynamics_curve=data.get("dynamics_curve", "linear")
134
+ )
135
 
136
 
137
  @dataclass
138
  class Latents:
139
+ """Container for model latents."""
140
  hidden_states: List[torch.Tensor]
141
  attentions: List[torch.Tensor]
142
  num_layers: int
143
  num_tokens: int
144
+ metadata: Dict[str, Any] = field(default_factory=dict)
145
+
146
+
147
+ # Music Components
148
+
149
+ class ScaleManager:
150
+ """Manages musical scales and modes."""
151
+
152
+ def __init__(self):
153
+ """Initialize with default scales."""
154
+ self.scales = {
155
+ "C pentatonic": ScaleDefinition(
156
+ "C pentatonic",
157
+ [60, 62, 65, 67, 70, 72, 74, 77],
158
+ "Major pentatonic scale"
159
+ ),
160
+ "C major": ScaleDefinition(
161
+ "C major",
162
+ [60, 62, 64, 65, 67, 69, 71, 72],
163
+ "Major scale (Ionian mode)"
164
+ ),
165
+ "A minor": ScaleDefinition(
166
+ "A minor",
167
+ [57, 59, 60, 62, 64, 65, 67, 69],
168
+ "Natural minor scale (Aeolian mode)"
169
+ ),
170
+ "D dorian": ScaleDefinition(
171
+ "D dorian",
172
+ [62, 64, 65, 67, 69, 71, 72, 74],
173
+ "Dorian mode - minor with raised 6th"
174
+ ),
175
+ "E phrygian": ScaleDefinition(
176
+ "E phrygian",
177
+ [64, 65, 67, 69, 71, 72, 74, 76],
178
+ "Phrygian mode - minor with lowered 2nd"
179
+ ),
180
+ "G mixolydian": ScaleDefinition(
181
+ "G mixolydian",
182
+ [67, 69, 71, 72, 74, 76, 77, 79],
183
+ "Mixolydian mode - major with lowered 7th"
184
+ ),
185
+ "Blues scale": ScaleDefinition(
186
+ "Blues scale",
187
+ [60, 63, 65, 66, 67, 70, 72, 75],
188
+ "Blues scale with blue notes"
189
+ ),
190
+ "Chromatic": ScaleDefinition(
191
+ "Chromatic",
192
+ list(range(60, 72)),
193
+ "All 12 semitones"
194
+ )
195
+ }
196
+
197
+ def get_scale(self, name: str) -> Optional[ScaleDefinition]:
198
+ """Get scale by name."""
199
+ return self.scales.get(name)
200
+
201
+ def add_custom_scale(self, name: str, notes: List[int], description: str = "") -> ScaleDefinition:
202
+ """Add a custom scale."""
203
+ scale = ScaleDefinition(name, notes, description)
204
+ self.scales[name] = scale
205
+ return scale
206
+
207
+ def list_scales(self) -> List[str]:
208
+ """Get list of available scale names."""
209
+ return list(self.scales.keys())
210
+
211
+
212
+ class InstrumentPresetManager:
213
+ """Manages instrument presets for different musical styles."""
214
+
215
+ def __init__(self):
216
+ """Initialize with default presets."""
217
+ self.presets = {
218
+ "Ensemble (melody+bass+pad etc.)": [
219
+ InstrumentMapping(0, MusicRole.MELODY, 0, "Piano"),
220
+ InstrumentMapping(33, MusicRole.BASS, 1, "Electric Bass"),
221
+ InstrumentMapping(46, MusicRole.HARMONY, 2, "Harp"),
222
+ InstrumentMapping(48, MusicRole.PAD, 3, "String Ensemble"),
223
+ InstrumentMapping(11, MusicRole.ACCENT, 4, "Vibraphone"),
224
+ InstrumentMapping(89, MusicRole.ATMOSPHERE, 5, "Pad Warm")
225
+ ],
226
+ "Piano Trio (melody+bass+harmony)": [
227
+ InstrumentMapping(0, MusicRole.MELODY, 0, "Piano"),
228
+ InstrumentMapping(33, MusicRole.BASS, 1, "Electric Bass"),
229
+ InstrumentMapping(0, MusicRole.HARMONY, 2, "Piano"),
230
+ InstrumentMapping(48, MusicRole.PAD, 3, "String Ensemble"),
231
+ InstrumentMapping(0, MusicRole.ACCENT, 4, "Piano"),
232
+ InstrumentMapping(0, MusicRole.ATMOSPHERE, 5, "Piano")
233
+ ],
234
+ "Pads & Atmosphere": [
235
+ InstrumentMapping(48, MusicRole.PAD, 0, "String Ensemble"),
236
+ InstrumentMapping(48, MusicRole.PAD, 1, "String Ensemble"),
237
+ InstrumentMapping(89, MusicRole.ATMOSPHERE, 2, "Pad Warm"),
238
+ InstrumentMapping(89, MusicRole.ATMOSPHERE, 3, "Pad Warm"),
239
+ InstrumentMapping(46, MusicRole.HARMONY, 4, "Harp"),
240
+ InstrumentMapping(11, MusicRole.ACCENT, 5, "Vibraphone")
241
+ ],
242
+ "Orchestral": [
243
+ InstrumentMapping(40, MusicRole.MELODY, 0, "Violin"),
244
+ InstrumentMapping(42, MusicRole.BASS, 1, "Cello"),
245
+ InstrumentMapping(46, MusicRole.HARMONY, 2, "Harp"),
246
+ InstrumentMapping(48, MusicRole.PAD, 3, "String Ensemble"),
247
+ InstrumentMapping(73, MusicRole.ACCENT, 4, "Flute"),
248
+ InstrumentMapping(49, MusicRole.ATMOSPHERE, 5, "Slow Strings")
249
+ ],
250
+ "Electronic": [
251
+ InstrumentMapping(80, MusicRole.MELODY, 0, "Lead Square"),
252
+ InstrumentMapping(38, MusicRole.BASS, 1, "Synth Bass"),
253
+ InstrumentMapping(81, MusicRole.HARMONY, 2, "Lead Sawtooth"),
254
+ InstrumentMapping(90, MusicRole.PAD, 3, "Pad Polysynth"),
255
+ InstrumentMapping(82, MusicRole.ACCENT, 4, "Lead Calliope"),
256
+ InstrumentMapping(91, MusicRole.ATMOSPHERE, 5, "Pad Bowed")
257
+ ]
258
+ }
259
+
260
+ def get_preset(self, name: str) -> List[InstrumentMapping]:
261
+ """Get instrument preset by name."""
262
+ return self.presets.get(name, self.presets["Ensemble (melody+bass+pad etc.)"])
263
+
264
+ def list_presets(self) -> List[str]:
265
+ """Get list of available preset names."""
266
+ return list(self.presets.keys())
267
+
268
+
269
+ # Music Generation Components
270
+
271
+ class MusicMathUtils:
272
+ """Utility class for music-related mathematical operations."""
273
+
274
+ @staticmethod
275
+ def entropy(p: np.ndarray) -> float:
276
+ """Calculate Shannon entropy of a probability distribution."""
277
+ p = p / (p.sum() + 1e-9)
278
+ return float(-np.sum(p * np.log2(p + 1e-9)))
279
+
280
+ @staticmethod
281
+ def quantize_time(time_val: int, grid: int = 120) -> int:
282
+ """Quantize time value to grid."""
283
+ return int(round(time_val / grid) * grid)
284
+
285
+ @staticmethod
286
+ def norm_to_scale(val: float, scale: np.ndarray, octave_range: int = 2) -> int:
287
+ """Map normalized value to scale note with octave range."""
288
+ octave = int(abs(val) * octave_range) * 12
289
+ note_idx = int(abs(val * 100) % len(scale))
290
+ return int(scale[note_idx] + octave)
291
+
292
+ @staticmethod
293
+ def apply_dynamics_curve(value: float, curve_type: str = "linear") -> float:
294
+ """Apply dynamics curve to a value."""
295
+ value = np.clip(value, 0, 1)
296
+ if curve_type == "exponential":
297
+ return value ** 2
298
+ elif curve_type == "logarithmic":
299
+ return np.log1p(value * np.e) / np.log1p(np.e)
300
+ else: # linear
301
+ return value
302
+
303
+
304
+ class NoteGenerator:
305
+ """Generates notes based on neural network latents."""
306
+
307
+ # Role-specific frequency multipliers
308
+ ROLE_FREQUENCIES = {
309
+ MusicRole.MELODY: 2.0,
310
+ MusicRole.BASS: 0.5,
311
+ MusicRole.HARMONY: 1.5,
312
+ MusicRole.PAD: 0.25,
313
+ MusicRole.ACCENT: 3.0,
314
+ MusicRole.ATMOSPHERE: 0.33
315
+ }
316
+
317
+ # Role-specific weight distributions
318
+ ROLE_WEIGHTS = {
319
+ MusicRole.MELODY: np.array([0.4, 0.2, 0.2, 0.1, 0.1]),
320
+ MusicRole.BASS: np.array([0.1, 0.4, 0.1, 0.3, 0.1]),
321
+ MusicRole.HARMONY: np.array([0.2, 0.2, 0.3, 0.2, 0.1]),
322
+ MusicRole.PAD: np.array([0.1, 0.3, 0.1, 0.1, 0.4]),
323
+ MusicRole.ACCENT: np.array([0.5, 0.1, 0.2, 0.1, 0.1]),
324
+ MusicRole.ATMOSPHERE: np.array([0.1, 0.2, 0.1, 0.2, 0.4])
325
+ }
326
+
327
+ def __init__(self, config: GenerationConfig):
328
+ """Initialize with generation configuration."""
329
+ self.config = config
330
+ self.math_utils = MusicMathUtils()
331
+ self.history: Dict[int, int] = {}
332
+
333
+ def create_note_probability(
334
+ self,
335
+ layer_idx: int,
336
+ token_idx: int,
337
+ attention_val: float,
338
+ hidden_state: np.ndarray,
339
+ num_tokens: int,
340
+ role: MusicRole
341
+ ) -> float:
342
+ """Calculate probability of playing a note based on multiple factors."""
343
+ # Base probability from attention
344
+ base_prob = 1 / (1 + np.exp(-10 * (attention_val - 0.5)))
345
+
346
+ # Temporal factor based on role frequency
347
+ temporal_factor = 0.5 + 0.5 * np.sin(
348
+ 2 * np.pi * self.ROLE_FREQUENCIES[role] * token_idx / max(1, num_tokens)
349
+ )
350
+
351
+ # Energy factor from hidden state norm
352
+ energy = np.linalg.norm(hidden_state)
353
+ energy_factor = np.tanh(energy / 10)
354
+
355
+ # Variance factor
356
+ local_variance = np.var(hidden_state)
357
+ variance_factor = 1 - np.exp(-local_variance)
358
+
359
+ # Entropy factor
360
+ state_entropy = self.math_utils.entropy(np.abs(hidden_state))
361
+ max_entropy = np.log2(max(2, hidden_state.shape[0]))
362
+ entropy_factor = state_entropy / max_entropy
363
+
364
+ # Combine factors with role-specific weights
365
+ factors = np.array([base_prob, temporal_factor, energy_factor, variance_factor, entropy_factor])
366
+ weights = self.ROLE_WEIGHTS[role]
367
+ combined_prob = float(np.dot(weights, factors))
368
+
369
+ # Add deterministic noise for variation
370
+ noise_seed = layer_idx * 1000 + token_idx
371
+ noise = 0.1 * (np.sin(noise_seed * 0.1) + np.cos(noise_seed * 0.23)) / 2
372
+
373
+ # Apply dynamics curve
374
+ final_prob = (combined_prob + noise) ** 1.5
375
+ final_prob = self.math_utils.apply_dynamics_curve(final_prob, self.config.dynamics_curve)
376
+
377
+ return float(np.clip(final_prob, 0, 1))
378
+
379
+ def should_play_note(
380
+ self,
381
+ layer_idx: int,
382
+ token_idx: int,
383
+ attention_val: float,
384
+ hidden_state: np.ndarray,
385
+ num_tokens: int,
386
+ role: MusicRole
387
+ ) -> bool:
388
+ """Determine if a note should be played."""
389
+ prob = self.create_note_probability(
390
+ layer_idx, token_idx, attention_val, hidden_state, num_tokens, role
391
+ )
392
+
393
+ # Adjust probability based on silence duration
394
+ if layer_idx in self.history:
395
+ last_played = self.history[layer_idx]
396
+ silence_duration = token_idx - last_played
397
+ prob *= (1 + np.tanh(silence_duration / 5) * 0.5)
398
+
399
+ # Stochastic decision
400
+ play_note = np.random.random() < prob
401
+
402
+ if play_note:
403
+ self.history[layer_idx] = token_idx
404
+
405
+ return play_note
406
+
407
+ def generate_notes_for_role(
408
+ self,
409
+ role: MusicRole,
410
+ hidden_state: np.ndarray,
411
+ scale: np.ndarray
412
+ ) -> List[int]:
413
+ """Generate notes based on role and hidden state."""
414
+ if role == MusicRole.MELODY:
415
+ note = self.math_utils.norm_to_scale(
416
+ hidden_state[0], scale, octave_range=1
417
+ )
418
+ return [note]
419
+
420
+ elif role == MusicRole.BASS:
421
+ note = self.math_utils.norm_to_scale(
422
+ hidden_state[0], scale, octave_range=0
423
+ ) - 12
424
+ return [note]
425
+
426
+ elif role == MusicRole.HARMONY:
427
+ return [
428
+ self.math_utils.norm_to_scale(hidden_state[i], scale, octave_range=1)
429
+ for i in range(0, min(2, len(hidden_state)), 1)
430
+ ]
431
+
432
+ elif role == MusicRole.PAD:
433
+ return [
434
+ self.math_utils.norm_to_scale(hidden_state[i], scale, octave_range=1)
435
+ for i in range(0, min(3, len(hidden_state)), 2)
436
+ ]
437
+
438
+ elif role == MusicRole.ACCENT:
439
+ note = self.math_utils.norm_to_scale(
440
+ hidden_state[0], scale, octave_range=2
441
+ ) + 12
442
+ return [note]
443
+
444
+ else: # ATMOSPHERE
445
+ return [
446
+ self.math_utils.norm_to_scale(hidden_state[i], scale, octave_range=1)
447
+ for i in range(0, min(2, len(hidden_state)), 3)
448
+ ]
449
+
450
+ def calculate_velocity(
451
+ self,
452
+ role: MusicRole,
453
+ attention_strength: float
454
+ ) -> int:
455
+ """Calculate note velocity based on role and attention."""
456
+ base_velocity = int(
457
+ attention_strength * (self.config.velocity_range[1] - self.config.velocity_range[0])
458
+ + self.config.velocity_range[0]
459
+ )
460
+
461
+ # Role-specific adjustments
462
+ if role == MusicRole.MELODY:
463
+ velocity = min(base_velocity + 10, 127)
464
+ elif role == MusicRole.ACCENT:
465
+ velocity = min(base_velocity + 20, 127)
466
+ elif role in [MusicRole.PAD, MusicRole.ATMOSPHERE]:
467
+ velocity = max(base_velocity - 10, 20)
468
+ else:
469
+ velocity = base_velocity
470
+
471
+ return velocity
472
+
473
+ def calculate_duration(
474
+ self,
475
+ role: MusicRole,
476
+ attention_matrix: np.ndarray
477
+ ) -> int:
478
+ """Calculate note duration based on role and attention."""
479
+ if role in [MusicRole.PAD, MusicRole.ATMOSPHERE]:
480
+ duration = self.config.base_tempo * 4
481
+ elif role == MusicRole.BASS:
482
+ duration = self.config.base_tempo
483
+ else:
484
+ try:
485
+ dur_factor = self.math_utils.entropy(attention_matrix.mean(axis=0)) / (
486
+ np.log2(attention_matrix.shape[-1]) + 1e-9
487
+ )
488
+ except Exception:
489
+ dur_factor = 0.5
490
+ duration = self.math_utils.quantize_time(
491
+ int(self.config.base_tempo * (0.5 + dur_factor * 1.5)),
492
+ self.config.quantization_grid
493
+ )
494
+
495
+ return duration
496
+
497
+
498
+ # Model Interaction
499
+
500
+ class LatentExtractor(ABC):
501
+ """Abstract base class for latent extraction strategies."""
502
+
503
+ @abstractmethod
504
+ def extract(self, text: str, config: GenerationConfig, progress=None) -> Latents:
505
+ """Extract latents from text."""
506
  pass
507
 
508
 
509
+ class MockLatentExtractor(LatentExtractor):
510
+ """Generate mock latents for testing without loading models."""
511
+
512
+ def extract(self, text: str, config: GenerationConfig, progress=None) -> Latents:
513
+ """Generate synthetic latents based on text."""
514
+ # Simulate token count based on text length
515
+ tokens = max(16, min(128, len(text.split()) * 4))
516
+ layers = min(config.num_layers_limit, 6)
517
+
518
+ # Generate deterministic but varied latents based on text
519
+ np.random.seed(hash(text) % 2**32)
520
+
521
+ hidden_states = [
522
+ torch.randn(1, tokens, 128) for _ in range(layers)
523
+ ]
524
+ attentions = [
525
+ torch.rand(1, 8, tokens, tokens) for _ in range(layers)
526
+ ]
527
+
528
+ metadata = {
529
+ "mode": "mock",
530
+ "text_length": len(text),
531
+ "generated_tokens": tokens,
532
+ "generated_layers": layers
533
+ }
534
+
535
+ return Latents(
536
+ hidden_states=hidden_states,
537
+ attentions=attentions,
538
+ num_layers=layers,
539
+ num_tokens=tokens,
540
+ metadata=metadata
541
+ )
542
+
543
+
544
+ class ModelLatentExtractor(LatentExtractor):
545
+ """Extract real latents from transformer models."""
546
+
547
+ @spaces.GPU(duration=45)
548
+ def extract(self, text: str, config: GenerationConfig, progress=None) -> Latents:
549
+ """Extract latents from a real transformer model."""
550
+ model_name = config.model_name
551
+
552
+ # Load tokenizer
553
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
554
+ if tokenizer.pad_token is None and tokenizer.eos_token is not None:
555
+ tokenizer.pad_token = tokenizer.eos_token
556
+
557
+ # Configure model loading
558
+ load_kwargs = {
559
+ "output_hidden_states": True,
560
+ "output_attentions": True,
561
+ "device_map": "cuda" if torch.cuda.is_available() else "cpu",
562
+ }
563
+
564
+ # Set appropriate dtype
565
+ try:
566
+ load_kwargs["torch_dtype"] = (
567
+ torch.bfloat16 if torch.cuda.is_available() else torch.float32
568
+ )
569
+ except Exception:
570
+ pass
571
+
572
+ # Load model
573
+ model = AutoModel.from_pretrained(model_name, **load_kwargs)
574
+
575
+ # Tokenize input
576
+ inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
577
+ device = next(model.parameters()).device
578
+ inputs = {k: v.to(device) for k, v in inputs.items()}
579
+
580
+ # Get model outputs
581
+ with torch.no_grad():
582
+ outputs = model(**inputs)
583
+ hidden_states = list(outputs.hidden_states)
584
+ attentions = list(outputs.attentions)
585
+
586
+ # Move to CPU to free VRAM
587
+ hidden_states = [hs.to("cpu") for hs in hidden_states]
588
+ attentions = [att.to("cpu") for att in attentions]
589
+
590
+ # Limit layers
591
+ layers = min(config.num_layers_limit, len(hidden_states))
592
+ tokens = hidden_states[0].shape[1]
593
+
594
+ # Clean up
595
+ try:
596
+ del model
597
+ if torch.cuda.is_available():
598
+ torch.cuda.empty_cache()
599
+ gc.collect()
600
+ except Exception:
601
+ pass
602
+
603
+ metadata = {
604
+ "mode": "full_model",
605
+ "model_name": model_name,
606
+ "actual_layers": len(hidden_states),
607
+ "used_layers": layers,
608
+ "tokens": tokens
609
+ }
610
+
611
+ return Latents(
612
+ hidden_states=hidden_states[:layers],
613
+ attentions=attentions[:layers],
614
+ num_layers=layers,
615
+ num_tokens=tokens,
616
+ metadata=metadata
617
+ )
618
+
619
+
620
+ class LatentExtractorFactory:
621
+ """Factory for creating appropriate latent extractors."""
622
+
623
+ @staticmethod
624
+ def create(compute_mode: ComputeMode) -> LatentExtractor:
625
+ """Create a latent extractor based on compute mode."""
626
+ if compute_mode == ComputeMode.MOCK_LATENTS:
627
+ return MockLatentExtractor()
628
+ else:
629
+ return ModelLatentExtractor()
630
+
631
+
632
+ # MIDI Generation
633
+
634
+ class MIDIRenderer:
635
+ """Renders MIDI files from latents."""
636
+
637
+ def __init__(self, config: GenerationConfig, instrument_manager: InstrumentPresetManager):
638
+ """Initialize MIDI renderer."""
639
+ self.config = config
640
+ self.instrument_manager = instrument_manager
641
+ self.note_generator = NoteGenerator(config)
642
+ self.math_utils = MusicMathUtils()
643
+
644
+ def render(self, latents: Latents) -> Tuple[bytes, Dict[str, Any]]:
645
+ """Render MIDI from latents."""
646
+ # Set random seeds for reproducibility
647
+ np.random.seed(self.config.seed)
648
+ random.seed(self.config.seed)
649
+ torch.manual_seed(self.config.seed)
650
+
651
+ # Prepare data
652
+ scale = np.array(self.config.scale.notes, dtype=int)
653
+ num_layers = latents.num_layers
654
+ num_tokens = latents.num_tokens
655
+
656
+ # Convert tensors to numpy
657
+ hidden_states = [
658
+ hs.float().numpy() if isinstance(hs, torch.Tensor) else hs
659
+ for hs in latents.hidden_states
660
+ ]
661
+ attentions = [
662
+ att.float().numpy() if isinstance(att, torch.Tensor) else att
663
+ for att in latents.attentions
664
+ ]
665
+
666
+ # Get instrument mappings
667
+ instrument_mappings = self.instrument_manager.get_preset(self.config.instrument_preset)
668
+
669
+ # Create MIDI file and tracks
670
+ midi_file = MidiFile()
671
+ tracks = self._create_tracks(midi_file, num_layers, instrument_mappings)
672
+
673
+ # Generate notes
674
+ stats = self._generate_notes(
675
+ tracks, hidden_states, attentions,
676
+ scale, num_tokens, instrument_mappings
677
+ )
678
+
679
+ # Convert to bytes
680
+ bio = io.BytesIO()
681
+ midi_file.save(file=bio)
682
+ bio.seek(0)
683
+
684
+ # Prepare metadata
685
+ metadata = {
686
+ "config": self.config.to_dict(),
687
+ "latents_info": latents.metadata,
688
+ "stats": stats,
689
+ "timestamp": time.time()
690
+ }
691
+
692
+ return bio.read(), metadata
693
+
694
+ def _create_tracks(
695
+ self,
696
+ midi_file: MidiFile,
697
+ num_layers: int,
698
+ instrument_mappings: List[InstrumentMapping]
699
+ ) -> List[MidiTrack]:
700
+ """Create MIDI tracks with instrument assignments."""
701
+ tracks = []
702
+
703
  for layer_idx in range(num_layers):
704
+ track = MidiTrack()
705
+ midi_file.tracks.append(track)
706
+ tracks.append(track)
707
+
708
+ # Get instrument mapping for this layer
709
+ if layer_idx < len(instrument_mappings):
710
+ mapping = instrument_mappings[layer_idx]
711
  else:
712
+ # Default to piano if not enough mappings
713
+ mapping = InstrumentMapping(0, MusicRole.MELODY, layer_idx % 16)
714
+
715
+ # Set instrument
716
+ track.append(Message(
717
+ "program_change",
718
+ program=mapping.program,
719
+ time=0,
720
+ channel=mapping.channel
721
+ ))
722
+
723
+ # Add track name
724
+ if mapping.name:
725
+ track.append(mido.MetaMessage(
726
+ "track_name",
727
+ name=f"{mapping.name} - {mapping.role.value}",
728
+ time=0
729
+ ))
730
+
731
+ return tracks
732
+
733
+ def _generate_notes(
734
+ self,
735
+ tracks: List[MidiTrack],
736
+ hidden_states: List[np.ndarray],
737
+ attentions: List[np.ndarray],
738
+ scale: np.ndarray,
739
+ num_tokens: int,
740
+ instrument_mappings: List[InstrumentMapping]
741
+ ) -> Dict[str, Any]:
742
+ """Generate notes for all tracks."""
743
+ current_time = [0] * len(tracks)
744
+ notes_count = [0] * len(tracks)
745
+
746
+ for token_idx in range(num_tokens):
747
+ # Update time periodically
748
+ if token_idx > 0 and token_idx % 4 == 0:
749
+ for layer_idx in range(len(tracks)):
750
+ current_time[layer_idx] += self.config.base_tempo
751
+
752
+ # Calculate panning
753
+ pan = 64 + int(32 * np.sin(token_idx * math.pi / max(1, num_tokens)))
754
+
755
+ # Generate notes for each layer
756
+ for layer_idx in range(len(tracks)):
757
+ if layer_idx >= len(instrument_mappings):
758
+ continue
759
+
760
+ mapping = instrument_mappings[layer_idx]
761
+
762
+ # Get attention and hidden state
763
+ attn_matrix = attentions[min(layer_idx, len(attentions) - 1)][0, :, token_idx, :]
764
+ attention_strength = float(np.mean(attn_matrix))
765
+ layer_vec = hidden_states[layer_idx][0, token_idx]
766
+
767
+ # Check if note should be played
768
+ if not self.note_generator.should_play_note(
769
+ layer_idx, token_idx, attention_strength,
770
+ layer_vec, num_tokens, mapping.role
771
+ ):
772
+ continue
773
+
774
+ # Generate notes
775
+ notes_to_play = self.note_generator.generate_notes_for_role(
776
+ mapping.role, layer_vec, scale
777
+ )
778
+
779
+ # Calculate velocity and duration
780
+ velocity = self.note_generator.calculate_velocity(
781
+ mapping.role, attention_strength
782
+ )
783
+ duration = self.note_generator.calculate_duration(
784
+ mapping.role, attn_matrix
785
+ )
786
+
787
+ # Add notes to track
788
+ for note in notes_to_play:
789
+ note = max(21, min(108, int(note))) # Clamp to piano range
790
+
791
+ tracks[layer_idx].append(Message(
792
+ "note_on",
793
+ note=note,
794
+ velocity=velocity,
795
+ time=current_time[layer_idx],
796
+ channel=mapping.channel
797
+ ))
798
+
799
+ tracks[layer_idx].append(Message(
800
+ "note_off",
801
+ note=note,
802
+ velocity=0,
803
+ time=duration,
804
+ channel=mapping.channel
805
+ ))
806
+
807
+ current_time[layer_idx] = 0
808
+ notes_count[layer_idx] += 1
809
+
810
+ # Set panning on first token
811
+ if token_idx == 0:
812
+ tracks[layer_idx].append(Message(
813
+ "control_change",
814
+ control=10,
815
+ value=pan,
816
+ time=0,
817
+ channel=mapping.channel
818
+ ))
819
+
820
+ return {
821
+ "num_layers": len(tracks),
822
+ "num_tokens": num_tokens,
823
+ "notes_per_layer": notes_count,
824
+ "total_notes": int(sum(notes_count)),
825
+ "tempo_ticks_per_beat": int(self.config.base_tempo),
826
+ "scale": list(map(int, scale.tolist())),
827
+ }
828
+
829
+
830
+ # Main Orchestrator
831
+
832
+ class LLMForestOrchestra:
833
+ """Main orchestrator class that coordinates the entire pipeline."""
834
+
835
+ DEFAULT_MODEL = "unsloth/Qwen3-14B-Base"
836
+
837
+ def __init__(self):
838
+ """Initialize the orchestra."""
839
+ self.scale_manager = ScaleManager()
840
+ self.instrument_manager = InstrumentPresetManager()
841
+ self.saved_configs: Dict[str, GenerationConfig] = {}
842
+
843
+ def generate(
844
+ self,
845
+ text: str,
846
+ model_name: str,
847
+ compute_mode: str,
848
+ base_tempo: int,
849
+ velocity_range: Tuple[int, int],
850
+ scale_name: str,
851
+ custom_scale_notes: Optional[List[int]],
852
+ num_layers: int,
853
+ instrument_preset: str,
854
+ seed: int,
855
+ quantization_grid: int = 120,
856
+ octave_range: int = 2,
857
+ dynamics_curve: str = "linear"
858
+ ) -> Tuple[str, Dict[str, Any]]:
859
+ """Generate MIDI from text input."""
860
+ # Get or create scale
861
+ if scale_name == "Custom":
862
+ if not custom_scale_notes:
863
+ raise ValueError("Custom scale requires note list")
864
+ scale = ScaleDefinition("Custom", custom_scale_notes)
865
+ else:
866
+ scale = self.scale_manager.get_scale(scale_name)
867
+ if scale is None:
868
+ raise ValueError(f"Unknown scale: {scale_name}")
869
+
870
+ # Create configuration
871
+ config = GenerationConfig(
872
+ model_name=model_name or self.DEFAULT_MODEL,
873
+ compute_mode=ComputeMode(compute_mode),
874
+ base_tempo=base_tempo,
875
+ velocity_range=velocity_range,
876
+ scale=scale,
877
+ num_layers_limit=num_layers,
878
+ seed=seed,
879
+ instrument_preset=instrument_preset,
880
+ quantization_grid=quantization_grid,
881
+ octave_range=octave_range,
882
+ dynamics_curve=dynamics_curve
883
+ )
884
+
885
+ # Validate configuration
886
+ config.validate()
887
+
888
+ # Extract latents
889
+ extractor = LatentExtractorFactory.create(config.compute_mode)
890
+ latents = extractor.extract(text, config)
891
+
892
+ # Render MIDI
893
+ renderer = MIDIRenderer(config, self.instrument_manager)
894
+ midi_bytes, metadata = renderer.render(latents)
895
+
896
+ # Save MIDI file
897
+ filename = f"llm_forest_orchestra_{uuid.uuid4().hex[:8]}.mid"
898
+ with open(filename, "wb") as f:
899
+ f.write(midi_bytes)
900
+
901
+ return filename, metadata
902
+
903
+ def save_config(self, name: str, config: GenerationConfig):
904
+ """Save a configuration for later use."""
905
+ self.saved_configs[name] = config
906
+
907
+ def load_config(self, name: str) -> Optional[GenerationConfig]:
908
+ """Load a saved configuration."""
909
+ return self.saved_configs.get(name)
910
+
911
+ def export_config(self, config: GenerationConfig, filepath: str):
912
+ """Export configuration to JSON file."""
913
+ with open(filepath, "w") as f:
914
+ json.dump(config.to_dict(), f, indent=2)
915
+
916
+ def import_config(self, filepath: str) -> GenerationConfig:
917
+ """Import configuration from JSON file."""
918
+ with open(filepath, "r") as f:
919
+ data = json.load(f)
920
+ return GenerationConfig.from_dict(data, self.scale_manager)
921
+
922
+
923
+ # Gradio UI
924
+
925
+ class GradioInterface:
926
+ """Manages the Gradio user interface."""
927
+
928
+ DESCRIPTION = """
929
+ # 🌲 LLM Forest Orchestra — Sonify Transformer Internals
930
+
931
+ Transform the hidden states and attention patterns of language models into multi-layered musical compositions.
932
+
933
+ ## 🍄 Inspiration
934
+
935
+ This project is inspired by the way **mushrooms and mycelial networks in forests**
936
+ connect plants and trees, forming a living web of communication and resource sharing.
937
+ These connections can be turned into ethereal music.
938
+ Just as signals move through these hidden connections, transformer models also
939
+ pass hidden states and attentions across their layers. Here, those hidden
940
+ connections are translated into **music**, analogous to the forest's secret orchestra.
941
+
942
+ ## Features
943
+ - **Two compute modes**: Full model (GPU) or Mock latents (CPU-friendly)
944
+ - **Multiple musical scales**: From pentatonic to chromatic
945
+ - **Instrument presets**: Orchestral, electronic, ensemble, and more
946
+ - **Advanced controls**: Dynamics curves, quantization, velocity ranges
947
+ - **Export**: Standard MIDI files for further editing in your DAW
948
+ """
949
+
950
+ EXAMPLE_TEXT = """Joy cascades in golden waterfalls, crashing into pools of melancholy blue.
951
+ Anger burns red through veins of marble, while serenity floats on clouds of softest grey.
952
+ Love pulses in waves of crimson and rose, intertwining with longing's purple haze.
953
+ Each feeling resonates at its own frequency, painting music across the soul's canvas."""
954
+
955
+ def __init__(self, orchestra: LLMForestOrchestra):
956
+ """Initialize the interface."""
957
+ self.orchestra = orchestra
958
+
959
+ def create_interface(self) -> gr.Blocks:
960
+ """Create the Gradio interface."""
961
+ with gr.Blocks(title="LLM Forest Orchestra", theme=gr.themes.Soft()) as demo:
962
+ gr.Markdown(self.DESCRIPTION)
963
+
964
+ with gr.Tabs():
965
+ with gr.TabItem("🎵 Generate Music"):
966
+ self._create_generation_tab()
967
+
968
+ return demo
969
+
970
+ def _create_generation_tab(self):
971
+ """Create the main generation tab."""
972
+ with gr.Row():
973
+ with gr.Column(scale=1):
974
+ text_input = gr.Textbox(
975
+ value=self.EXAMPLE_TEXT,
976
+ label="Input Text",
977
+ lines=8,
978
+ placeholder="Enter text to sonify..."
979
+ )
980
+
981
+ model_name = gr.Textbox(
982
+ value=self.orchestra.DEFAULT_MODEL,
983
+ label="Hugging Face Model",
984
+ info="Model must support output_hidden_states and output_attentions"
985
+ )
986
+
987
+ compute_mode = gr.Radio(
988
+ choices=["Full model", "Mock latents"],
989
+ value="Mock latents",
990
+ label="Compute Mode",
991
+ info="Mock latents for quick CPU-only demo"
992
+ )
993
+
994
+ with gr.Row():
995
+ instrument_preset = gr.Dropdown(
996
+ choices=self.orchestra.instrument_manager.list_presets(),
997
+ value="Ensemble (melody+bass+pad etc.)",
998
+ label="Instrument Preset"
999
+ )
1000
+
1001
+ scale_choice = gr.Dropdown(
1002
+ choices=self.orchestra.scale_manager.list_scales() + ["Custom"],
1003
+ value="C pentatonic",
1004
+ label="Musical Scale"
1005
+ )
1006
+
1007
+ custom_scale = gr.Textbox(
1008
+ value="",
1009
+ label="Custom Scale Notes",
1010
+ placeholder="60,62,65,67,70",
1011
+ visible=False
1012
+ )
1013
+
1014
+ with gr.Row():
1015
+ base_tempo = gr.Slider(
1016
+ 120, 960,
1017
+ value=480,
1018
+ step=1,
1019
+ label="Tempo (ticks per beat)"
1020
+ )
1021
+
1022
+ num_layers = gr.Slider(
1023
+ 1, 6,
1024
+ value=6,
1025
+ step=1,
1026
+ label="Max Layers"
1027
+ )
1028
+
1029
+ with gr.Row():
1030
+ velocity_low = gr.Slider(
1031
+ 1, 126,
1032
+ value=40,
1033
+ step=1,
1034
+ label="Min Velocity"
1035
+ )
1036
+
1037
+ velocity_high = gr.Slider(
1038
+ 2, 127,
1039
+ value=90,
1040
+ step=1,
1041
+ label="Max Velocity"
1042
+ )
1043
+
1044
+ seed = gr.Number(
1045
+ value=42,
1046
+ precision=0,
1047
+ label="Random Seed"
1048
+ )
1049
+
1050
+ generate_btn = gr.Button(
1051
+ "🎼 Generate MIDI",
1052
+ variant="primary",
1053
+ size="lg"
1054
+ )
1055
+
1056
+ with gr.Column(scale=1):
1057
+ midi_output = gr.File(
1058
+ label="Generated MIDI File",
1059
+ file_types=[".mid", ".midi"]
1060
+ )
1061
+
1062
+ stats_display = gr.Markdown(label="Quick Stats")
1063
+
1064
+ metadata_json = gr.Code(
1065
+ label="Metadata (JSON)",
1066
+ language="json"
1067
+ )
1068
+
1069
+ with gr.Row():
1070
+ play_instructions = gr.Markdown(
1071
+ """
1072
+ ### 🎧 How to Play
1073
+ 1. Download the MIDI file
1074
+ 2. Open in any DAW or MIDI player
1075
+ 3. Adjust instruments and effects as desired
1076
+ 4. Export to audio format
1077
+ """
1078
+ )
1079
+
1080
+ # Set up interactions
1081
+ def update_custom_scale_visibility(choice):
1082
+ return gr.update(visible=(choice == "Custom"))
1083
+
1084
+ scale_choice.change(
1085
+ update_custom_scale_visibility,
1086
+ inputs=[scale_choice],
1087
+ outputs=[custom_scale]
1088
+ )
1089
+
1090
+ def generate_wrapper(
1091
+ text, model_name, compute_mode, base_tempo,
1092
+ velocity_low, velocity_high, scale_choice,
1093
+ custom_scale, num_layers, instrument_preset, seed
1094
+ ):
1095
+ """Wrapper for generation with error handling."""
1096
+ try:
1097
+ # Parse custom scale if needed
1098
+ custom_notes = None
1099
+ if scale_choice == "Custom" and custom_scale:
1100
+ custom_notes = [int(x.strip()) for x in custom_scale.split(",")]
1101
+
1102
+ # Generate
1103
+ filename, metadata = self.orchestra.generate(
1104
+ text=text,
1105
+ model_name=model_name,
1106
+ compute_mode=compute_mode,
1107
+ base_tempo=int(base_tempo),
1108
+ velocity_range=(int(velocity_low), int(velocity_high)),
1109
+ scale_name=scale_choice,
1110
+ custom_scale_notes=custom_notes,
1111
+ num_layers=int(num_layers),
1112
+ instrument_preset=instrument_preset,
1113
+ seed=int(seed)
1114
+ )
1115
+
1116
+ # Format stats
1117
+ stats = metadata.get("stats", {})
1118
+ stats_text = f"""
1119
+ ### Generation Statistics
1120
+ - **Layers Used**: {stats.get('num_layers', 'N/A')}
1121
+ - **Tokens Processed**: {stats.get('num_tokens', 'N/A')}
1122
+ - **Total Notes**: {stats.get('total_notes', 'N/A')}
1123
+ - **Notes per Layer**: {stats.get('notes_per_layer', [])}
1124
+ - **Scale**: {stats.get('scale', [])}
1125
+ - **Tempo**: {stats.get('tempo_ticks_per_beat', 'N/A')} ticks/beat
1126
+ """
1127
+
1128
+ return filename, stats_text, json.dumps(metadata, indent=2)
1129
+
1130
+ except Exception as e:
1131
+ error_msg = f"### ❌ Error\n{str(e)}"
1132
+ return None, error_msg, json.dumps({"error": str(e)}, indent=2)
1133
+
1134
+ generate_btn.click(
1135
+ fn=generate_wrapper,
1136
+ inputs=[
1137
+ text_input, model_name, compute_mode, base_tempo,
1138
+ velocity_low, velocity_high, scale_choice,
1139
+ custom_scale, num_layers, instrument_preset, seed
1140
+ ],
1141
+ outputs=[midi_output, stats_display, metadata_json]
1142
+ )
1143
+
1144
+
1145
+ # Main Entry Point
1146
+
1147
+ def main():
1148
+ """Main entry point for the application."""
1149
+ # Initialize orchestra
1150
+ orchestra = LLMForestOrchestra()
1151
+
1152
+ # Create interface
1153
+ interface = GradioInterface(orchestra)
1154
+ demo = interface.create_interface()
1155
+
1156
+ # Launch
1157
+ demo.launch()
1158
 
1159
 
1160
  if __name__ == "__main__":
1161
+ main()
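
For quick verification of the refactor outside the Gradio UI, the new classes can in principle be driven directly. The snippet below is a minimal sketch, not part of this commit: it assumes `app.py` is importable in the Space environment (with `torch`, `transformers`, `gradio`, `spaces`, and `mido` installed) and uses the "Mock latents" mode so no 14B model is downloaded; all argument values shown are illustrative, not defaults.

```python
# Minimal sketch (not part of this commit): drive the refactored pipeline
# from app.py programmatically, using mock latents so it runs on CPU.
from app import LLMForestOrchestra

orchestra = LLMForestOrchestra()

# "Mock latents" skips model loading; the remaining arguments mirror the UI controls.
midi_path, metadata = orchestra.generate(
    text="Joy cascades in golden waterfalls.",
    model_name=LLMForestOrchestra.DEFAULT_MODEL,
    compute_mode="Mock latents",
    base_tempo=480,
    velocity_range=(40, 90),
    scale_name="C pentatonic",
    custom_scale_notes=None,
    num_layers=6,
    instrument_preset="Ensemble (melody+bass+pad etc.)",
    seed=42,
)

print("Wrote", midi_path)
print("Total notes:", metadata["stats"]["total_notes"])
```

Swapping `compute_mode` to "Full model" would route the same call through `ModelLatentExtractor` and the `@spaces.GPU`-decorated path instead.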