Locutusque committed on
Commit f47aa49 · verified · 1 Parent(s): 2ed725b

Create app.py

Files changed (1)
  1. app.py +392 -0
app.py ADDED
@@ -0,0 +1,392 @@
+ import os
+ import io
+ import gc
+ import json
+ import math
+ import time
+ import uuid
+ import spaces
+ import random
+ from dataclasses import dataclass
+ from typing import Dict, List, Tuple, Optional
+
+ import gradio as gr
+ import numpy as np
+ import torch
+ from transformers import AutoModel, AutoTokenizer
+
+ import mido
+ from mido import Message, MidiFile, MidiTrack
+
+ # -------------------- Defaults & Music Helpers --------------------
+
+ DEFAULT_MODEL = "unsloth/Qwen3-14B-Base"
+
+ SCALES = {
+     "C pentatonic": [60, 62, 65, 67, 70, 72, 74, 77],
+     "C major": [60, 62, 64, 65, 67, 69, 71, 72],
+     "A minor": [57, 59, 60, 62, 64, 65, 67, 69],
+     "Custom (comma-separated MIDI notes)": [],
+ }
+
+ LAYER_INSTRUMENT_PRESETS = {
+     "Ensemble (melody+bass+pad etc.)": {
+         0: (0, 'melody'),
+         1: (33, 'bass'),
+         2: (46, 'harmony'),
+         3: (48, 'pad'),
+         4: (11, 'accent'),
+         5: (89, 'atmosphere'),
+     },
+     "Piano Trio (melody+bass+harmony)": {
+         0: (0, 'melody'),
+         1: (33, 'bass'),
+         2: (0, 'harmony'),
+         3: (48, 'pad'),
+         4: (0, 'accent'),
+         5: (0, 'atmosphere'),
+     },
+     "Pads & Atmos": {
+         0: (48, 'pad'),
+         1: (48, 'pad'),
+         2: (89, 'atmosphere'),
+         3: (89, 'atmosphere'),
+         4: (46, 'harmony'),
+         5: (11, 'accent'),
+     },
+ }
+
+ @dataclass
+ class GenConfig:
+     model_name: str
+     compute_mode: str  # "Full model" or "Mock latents"
+     base_tempo: int
+     velocity_range: Tuple[int, int]
+     scale: List[int]
+     num_layers_limit: int
+     seed: int
+
+ # --- Core math helpers ---
+
+ def entropy(p: np.ndarray) -> float:
+     p = p / (p.sum() + 1e-9)
+     return float(-np.sum(p * np.log2(p + 1e-9)))
+
+ def quantize_time(time_val: int, grid: int = 120) -> int:
+     return int(round(time_val / grid) * grid)
+
+ def norm_to_scale(val: float, scale: np.ndarray, octave_range: int = 2) -> int:
+     octave = int(abs(val) * octave_range) * 12
+     note_idx = int(abs(val * 100) % len(scale))
+     return int(scale[note_idx] + octave)
+
+ ROLE_FREQS = {
+     'melody': 2.0,
+     'bass': 0.5,
+     'harmony': 1.5,
+     'pad': 0.25,
+     'accent': 3.0,
+     'atmosphere': 0.33
+ }
+
+ ROLE_WEIGHTS = {
+     'melody': np.array([0.4, 0.2, 0.2, 0.1, 0.1]),
+     'bass': np.array([0.1, 0.4, 0.1, 0.3, 0.1]),
+     'harmony': np.array([0.2, 0.2, 0.3, 0.2, 0.1]),
+     'pad': np.array([0.1, 0.3, 0.1, 0.1, 0.4]),
+     'accent': np.array([0.5, 0.1, 0.2, 0.1, 0.1]),
+     'atmosphere': np.array([0.1, 0.2, 0.1, 0.2, 0.4])
+ }
+
+ def create_note_probability(layer_idx, token_idx, attention_val, hidden_state, num_tokens, role: str):
+     # Blend several signals (attention strength, role-dependent rhythm, state energy,
+     # variance, and entropy) into a single probability of emitting a note here.
+     base_prob = 1 / (1 + np.exp(-10 * (attention_val - 0.5)))
+     temporal_factor = 0.5 + 0.5 * np.sin(2 * np.pi * ROLE_FREQS[role] * token_idx / max(1, num_tokens))
+     energy = np.linalg.norm(hidden_state)
+     energy_factor = np.tanh(energy / 10)
+     local_variance = np.var(hidden_state)
+     variance_factor = 1 - np.exp(-local_variance)
+     state_entropy = entropy(np.abs(hidden_state))
+     max_entropy = np.log2(max(2, hidden_state.shape[0]))
+     entropy_factor = state_entropy / max_entropy
+     factors = np.array([base_prob, temporal_factor, energy_factor, variance_factor, entropy_factor])
+     weights = ROLE_WEIGHTS[role]
+     combined_prob = float(np.dot(weights, factors))
+     noise_seed = layer_idx * 1000 + token_idx
+     noise = 0.1 * (np.sin(noise_seed * 0.1) + np.cos(noise_seed * 0.23)) / 2
+     # Clamp at zero before the fractional power so negative noise can never produce NaN.
+     final_prob = max(combined_prob + noise, 0.0) ** 1.5
+     return float(np.clip(final_prob, 0, 1))
+
+ def should_play_note_stochastic(layer_idx, token_idx, attention_val, hidden_state, num_tokens, role: str, history: Dict[int, int]):
+     prob = create_note_probability(layer_idx, token_idx, attention_val, hidden_state, num_tokens, role)
+     if layer_idx in history:
+         # The longer a layer has been silent, the more likely it is to play again.
+         last_played = history[layer_idx]
+         silence_duration = token_idx - last_played
+         prob *= (1 + np.tanh(silence_duration / 5) * 0.5)
+     play_note = np.random.random() < prob
+     if play_note:
+         history[layer_idx] = token_idx
+     return play_note
+
+ # -------------------- Model / Latents --------------------
+
+ @dataclass
+ class Latents:
+     hidden_states: List[torch.Tensor]
+     attentions: List[torch.Tensor]
+     num_layers: int
+     num_tokens: int
+
+ @spaces.GPU(duration=45)
+ def get_latents(text: str, model_name: str, compute_mode: str, max_layers: int, progress=gr.Progress(track_tqdm=True)) -> Latents:
+     if compute_mode == "Mock latents":
+         # Fast path for Spaces without big GPUs
+         tokens = max(16, min(128, len(text.split()) * 4))
+         layers = min(max_layers, 6)
+         hidden_states = [torch.randn(1, tokens, 128) for _ in range(layers)]
+         attentions = [torch.rand(1, 8, tokens, tokens) for _ in range(layers)]
+         return Latents(hidden_states=hidden_states, attentions=attentions, num_layers=layers, num_tokens=tokens)
+
+     # Full model path
+     tokenizer = AutoTokenizer.from_pretrained(model_name)
+     if tokenizer.pad_token is None and tokenizer.eos_token is not None:
+         tokenizer.pad_token = tokenizer.eos_token
+
+     # Try different memory-friendly loading strategies
+     load_kwargs = dict(
+         output_hidden_states=True,
+         output_attentions=True,
+         device_map="cuda",
+     )
+
+     # dtype heuristics
+     try:
+         load_kwargs["torch_dtype"] = torch.bfloat16 if torch.cuda.is_available() else torch.float32
+     except Exception:
+         pass
+
+     model = AutoModel.from_pretrained(model_name, **load_kwargs)
+
+     inputs = tokenizer(text, return_tensors="pt")
+     device = next(model.parameters()).device
+     inputs = {k: v.to(device) for k, v in inputs.items()}
+
+     with torch.no_grad():
+         outputs = model(**inputs)
+         hidden_states = list(outputs.hidden_states)
+         attentions = list(outputs.attentions)
+
+     # Move to CPU in float32 (bfloat16 tensors cannot be converted to NumPy) and free VRAM
+     hidden_states = [hs.to("cpu", dtype=torch.float32) for hs in hidden_states]
+     attentions = [att.to("cpu", dtype=torch.float32) for att in attentions]
+
+     # Trim layers
+     layers = min(max_layers, 6, len(hidden_states))
+     tokens = hidden_states[0].shape[1]
+
+     # Clean up VRAM
+     try:
+         del model
+         if torch.cuda.is_available():
+             torch.cuda.empty_cache()
+     except Exception:
+         pass
+
+     return Latents(hidden_states=hidden_states[:layers], attentions=attentions[:layers], num_layers=layers, num_tokens=tokens)
+
+ # -------------------- MIDI Rendering --------------------
+
+ def render_midi(latents: Latents, scale_notes: List[int], base_tempo: int, velocity_range: Tuple[int, int], preset_name: str, seed: int) -> Tuple[bytes, Dict]:
+     np.random.seed(seed)
+     random.seed(seed)
+
+     scale = np.array(scale_notes, dtype=int)
+     num_layers = latents.num_layers
+     num_tokens = latents.num_tokens
+     hidden_states = [hs.numpy() if isinstance(hs, torch.Tensor) else hs for hs in latents.hidden_states]
+     attentions = [att.numpy() if isinstance(att, torch.Tensor) else att for att in latents.attentions]
+
+     layer_instruments = LAYER_INSTRUMENT_PRESETS[preset_name]
+
+     mid = MidiFile()
+     tracks: List[MidiTrack] = []
+     for ch in range(num_layers):
+         track = MidiTrack()
+         mid.tracks.append(track)
+         tracks.append(track)
+         instrument = layer_instruments.get(ch, (0, 'melody'))[0]
+         track.append(Message('program_change', program=int(instrument), time=0, channel=ch))
+
+     history: Dict[int, int] = {}
+     current_time = [0] * num_layers
+     notes_count = [0] * num_layers
+
+     for token_idx in range(num_tokens):
+         # Advance every layer's pending delta time on a coarse 4-token grid
+         if token_idx > 0 and token_idx % 4 == 0:
+             for layer_idx in range(num_layers):
+                 current_time[layer_idx] += base_tempo
+
+         pan = 64 + int(32 * np.sin(token_idx * math.pi / max(1, num_tokens)))
+
+         for layer_idx in range(num_layers):
+             role = layer_instruments.get(layer_idx, (0, 'melody'))[1]
+
+             attn_matrix = attentions[min(layer_idx, len(attentions) - 1)][0, :, token_idx, :]
+             attention_strength = float(np.mean(attn_matrix))
+             layer_vec = hidden_states[layer_idx][0, token_idx]
+
+             if not should_play_note_stochastic(layer_idx, token_idx, attention_strength, layer_vec, num_tokens, role, history):
+                 continue
+
+             # Map hidden-state components to pitches according to the layer's role
+             if role == 'melody':
+                 note = norm_to_scale(layer_vec[0], scale, octave_range=1)
+                 notes_to_play = [note]
+             elif role == 'bass':
+                 note = norm_to_scale(layer_vec[0], scale, octave_range=0) - 12
+                 notes_to_play = [note]
+             elif role == 'harmony':
+                 notes_to_play = [norm_to_scale(layer_vec[i], scale, octave_range=1) for i in range(0, min(2, len(layer_vec)), 1)]
+             elif role == 'pad':
+                 notes_to_play = [norm_to_scale(layer_vec[i], scale, octave_range=1) for i in range(0, min(3, len(layer_vec)), 2)]
+             elif role == 'accent':
+                 note = norm_to_scale(layer_vec[0], scale, octave_range=2) + 12
+                 notes_to_play = [note]
+             else:
+                 notes_to_play = [norm_to_scale(layer_vec[i], scale, octave_range=1) for i in range(0, min(2, len(layer_vec)), 3)]
+
+             base_velocity = int(attention_strength * (velocity_range[1] - velocity_range[0]) + velocity_range[0])
+             if role == 'melody':
+                 velocity = min(base_velocity + 10, 127)
+             elif role == 'bass':
+                 velocity = base_velocity
+             elif role == 'accent':
+                 velocity = min(base_velocity + 20, 127)
+             else:
+                 velocity = max(base_velocity - 10, 20)
+
+             if role in ['pad', 'atmosphere']:
+                 duration = base_tempo * 4
+             elif role == 'bass':
+                 duration = base_tempo
+             else:
+                 try:
+                     dur_factor = entropy(attn_matrix.mean(axis=0)) / (np.log2(attn_matrix.shape[-1]) + 1e-9)
+                 except Exception:
+                     dur_factor = 0.5
+                 duration = quantize_time(int(base_tempo * (0.5 + dur_factor * 1.5)))
+
+             for note in notes_to_play:
+                 note = max(21, min(108, int(note)))
+                 tracks[layer_idx].append(Message('note_on', note=note, velocity=velocity, time=current_time[layer_idx], channel=layer_idx))
+                 tracks[layer_idx].append(Message('note_off', note=note, velocity=0, time=duration, channel=layer_idx))
+                 current_time[layer_idx] = 0
+                 notes_count[layer_idx] += 1
+
+             if token_idx == 0:
+                 tracks[layer_idx].append(Message('control_change', control=10, value=pan, time=0, channel=layer_idx))
+
+     # Save to bytes
+     bio = io.BytesIO()
+     mid.save(file=bio)
+     bio.seek(0)
+
+     meta = {
+         "num_layers": num_layers,
+         "num_tokens": num_tokens,
+         "notes_per_layer": notes_count,
+         "total_notes": int(sum(notes_count)),
+         "tempo_ticks_per_beat": int(base_tempo),
+         "scale": list(map(int, scale.tolist())),
+     }
+     return bio.read(), meta
+
+ # -------------------- Gradio UI --------------------
+
+ DESCRIPTION = """
+ # LLM Forest Orchestra — Sonify Transformer Internals
+ Turn hidden states and attentions into a multi-track MIDI composition.
+
+ - **Two compute modes**: *Full model* (loads a HF model and extracts latents) or *Mock latents* (quick demo with synthetic tensors — great for CPU-only Spaces).
+ - Choose **scale**, **tempo**, **velocity range**, and **instrument/role preset**.
+ - Exports a **MIDI** file you can arrange further in your DAW.
+ """
+
+ EXAMPLE_TEXT = """Joy cascades in golden waterfalls, crashing into pools of melancholy blue.
+ Anger burns red through veins of marble, while serenity floats on clouds of softest grey.
+ Love pulses in waves of crimson and rose, intertwining with longing's purple haze.
+ Each feeling resonates at its own frequency, painting music across the soul's canvas.
+ """
+
+ def parse_scale(selection: str, custom: str) -> List[int]:
+     if selection == "Custom (comma-separated MIDI notes)":
+         try:
+             # Fall back to the default scale if the custom field is empty or unparseable
+             notes = [int(x.strip()) for x in custom.split(",") if x.strip()]
+             return notes if notes else SCALES["C pentatonic"]
+         except Exception:
+             return SCALES["C pentatonic"]
+     return SCALES[selection] if SCALES[selection] else SCALES["C pentatonic"]
+
+ def generate(text, model_name, compute_mode, base_tempo, velocity_low, velocity_high, scale_choice, custom_scale, num_layers, preset, seed):
+     scale = parse_scale(scale_choice, custom_scale)
+     cfg = GenConfig(
+         model_name=model_name or DEFAULT_MODEL,
+         compute_mode=compute_mode,
+         base_tempo=int(base_tempo),
+         velocity_range=(int(velocity_low), int(velocity_high)),
+         scale=scale,
+         num_layers_limit=int(num_layers),
+         seed=int(seed),
+     )
+
+     # Get latents
+     latents = get_latents(text, cfg.model_name, cfg.compute_mode, cfg.num_layers_limit)
+
+     # Render MIDI
+     midi_bytes, meta = render_midi(latents, cfg.scale, cfg.base_tempo, cfg.velocity_range, preset, cfg.seed)
+
+     # Persist to a file for download
+     out_name = f"llm_forest_orchestra_{uuid.uuid4().hex[:8]}.mid"
+     with open(out_name, "wb") as f:
+         f.write(midi_bytes)
+
+     # Prepare quick stats
+     stats = (
+         f"Layers: {meta['num_layers']} | Tokens: {meta['num_tokens']} | "
+         f"Total notes: {meta['total_notes']} | Scale: {meta['scale']} | "
+         f"Tempo (ticks/beat): {meta['tempo_ticks_per_beat']}"
+     )
+
+     return out_name, stats, json.dumps(meta, indent=2)
+
+ with gr.Blocks(title="LLM Forest Orchestra — MIDI from Transformer Internals") as demo:
+     gr.Markdown(DESCRIPTION)
+
+     with gr.Row():
+         with gr.Column():
+             text = gr.Textbox(value=EXAMPLE_TEXT, label="Input text", lines=8)
+             model_name = gr.Textbox(value=DEFAULT_MODEL, label="HF model (base) to probe", info="Should support output_hidden_states & output_attentions")
+             compute_mode = gr.Radio(choices=["Mock latents", "Full model"], value="Full model", label="Compute mode")
+             preset = gr.Dropdown(choices=list(LAYER_INSTRUMENT_PRESETS.keys()), value="Ensemble (melody+bass+pad etc.)", label="Instrument/Role preset")
+             with gr.Row():
+                 base_tempo = gr.Slider(120, 960, value=480, step=1, label="Ticks per beat (tempo grid)")
+                 num_layers = gr.Slider(1, 6, value=6, step=1, label="Max layers to use")
+             with gr.Row():
+                 velocity_low = gr.Slider(1, 126, value=40, step=1, label="Velocity min")
+                 velocity_high = gr.Slider(2, 127, value=90, step=1, label="Velocity max")
+             with gr.Row():
+                 scale_choice = gr.Dropdown(choices=list(SCALES.keys()), value="C pentatonic", label="Scale")
+                 custom_scale = gr.Textbox(value="", label="Custom scale notes (e.g. 60,62,65,67)")
+             seed = gr.Number(value=42, precision=0, label="Random seed")
+
+             btn = gr.Button("Generate MIDI", variant="primary")
+
+         with gr.Column():
+             midi_file = gr.File(label="MIDI output (.mid)")
+             stats = gr.Markdown("")
+             meta_json = gr.Code(label="Meta (JSON)")
+
+     btn.click(
+         fn=generate,
+         inputs=[text, model_name, compute_mode, base_tempo, velocity_low, velocity_high, scale_choice, custom_scale, num_layers, preset, seed],
+         outputs=[midi_file, stats, meta_json]
+     )
+
+ if __name__ == "__main__":
+     demo.launch()