SHIKARICHACHA commited on
Commit
6b126b7
Β·
verified Β·
1 Parent(s): 981df93

Upload 4 files

Browse files
Files changed (3) hide show
  1. Dockerfile +6 -1
  2. README.md +1 -1
  3. app.py +922 -258
Dockerfile CHANGED
@@ -5,6 +5,9 @@ RUN apt-get update && apt-get install -y \
5
  fluidsynth \
6
  libsndfile1 \
7
  wget \
 
 
 
8
  && rm -rf /var/lib/apt/lists/*
9
 
10
  # Create app directory
@@ -27,8 +30,10 @@ RUN mkdir -p /app/soundfonts && \
27
  wget -O /app/soundfonts/Clarinet.sf2 https://musical-artifacts.com/artifacts/2744/SalC5Light.sf2 && \
28
  wget -O /app/soundfonts/Flute.sf2 https://musical-artifacts.com/artifacts/2744/SalC5Light.sf2
29
 
30
- # Create static directory for audio files
31
  RUN mkdir -p /app/static
 
 
32
 
33
  # Set environment variables
34
  ENV PYTHONUNBUFFERED=1
 
5
  fluidsynth \
6
  libsndfile1 \
7
  wget \
8
+ libfreetype6-dev \
9
+ libpng-dev \
10
+ pkg-config \
11
  && rm -rf /var/lib/apt/lists/*
12
 
13
  # Create app directory
 
30
  wget -O /app/soundfonts/Clarinet.sf2 https://musical-artifacts.com/artifacts/2744/SalC5Light.sf2 && \
31
  wget -O /app/soundfonts/Flute.sf2 https://musical-artifacts.com/artifacts/2744/SalC5Light.sf2
32
 
33
+ # Create necessary directories
34
  RUN mkdir -p /app/static
35
+ RUN mkdir -p /app/exercise_library
36
+ RUN mkdir -p /app/temp_audio
37
 
38
  # Set environment variables
39
  ENV PYTHONUNBUFFERED=1
README.md CHANGED
@@ -1,5 +1,5 @@
1
  ---
2
- title: Harmony Hub Ai Music Generation
3
  emoji: πŸ‘€
4
  colorFrom: indigo
5
  colorTo: indigo
 
1
  ---
2
+ title: Ai Music Generation Version 2
3
  emoji: πŸ‘€
4
  colorFrom: indigo
5
  colorTo: indigo
app.py CHANGED
@@ -1,21 +1,22 @@
1
  """
2
- Adaptive Music Exercise Generator (Integer Durations Only)
3
- ==========================================================
4
-
5
- Generates custom musical exercises with LLM using integer durations only (1=quarter note)
6
- Perfectly fits user-specified number of measures and time signature
7
-
8
  Major updates:
9
- - Fixed difficulty level implementation
10
- - Added duration sum display
11
- - Strict integer durations only (1=quarter note)
12
- - Simplified duration scaling
13
- - Strengthened LLM prompts for better JSON compliance
14
- - More robust note name sanitization
 
 
 
15
  """
16
 
17
  # -----------------------------------------------------------------------------
18
- # 1. Runtime-time package installation
19
  # -----------------------------------------------------------------------------
20
  import sys
21
  import subprocess
@@ -31,7 +32,8 @@ def install(packages: List[str]):
31
 
32
  install([
33
  "mido", "midi2audio", "pydub", "gradio",
34
- "requests", "os", "re", "json", "tempfile", "shutil"
 
35
  ])
36
 
37
  # -----------------------------------------------------------------------------
@@ -44,18 +46,28 @@ import tempfile
44
  import mido
45
  from mido import Message, MidiFile, MidiTrack, MetaMessage
46
  import re
47
- import os
48
- import shutil
49
- import subprocess as sp
50
  from midi2audio import FluidSynth
51
  from pydub import AudioSegment
52
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
53
 
54
  # -----------------------------------------------------------------------------
55
- # 3. Configuration & constants
56
  # -----------------------------------------------------------------------------
57
  MISTRAL_API_URL = "https://api.mistral.ai/v1/chat/completions"
58
- MISTRAL_API_KEY = "yQdfM8MLbX9uhInQ7id4iUTwN4h4pDLX"
59
 
60
  SOUNDFONT_URLS = {
61
  "Trumpet": "https://github.com/FluidSynth/fluidsynth/raw/master/sf2/Trumpet.sf2",
@@ -65,8 +77,9 @@ SOUNDFONT_URLS = {
65
  "Flute": "https://musical-artifacts.com/artifacts/2744/SalC5Light.sf2",
66
  }
67
 
68
- TICKS_PER_BEAT = 480
69
- TICKS_PER_UNIT = TICKS_PER_BEAT # 1 unit = 1 quarter note
 
70
 
71
  if not os.path.exists('/usr/bin/fluidsynth'):
72
  try:
@@ -75,75 +88,167 @@ if not os.path.exists('/usr/bin/fluidsynth'):
75
  print("Could not install FluidSynth automatically. Please install it manually.")
76
 
77
  os.makedirs("static", exist_ok=True)
 
 
78
 
79
  # -----------------------------------------------------------------------------
80
- # 4. Music theory helpers (note names β†”οΈŽ MIDI numbers)
81
  # -----------------------------------------------------------------------------
82
- NOTE_MAP = {"C":0,"C#":1,"DB":1,"D":2,"D#":3,"EB":3,"E":4,"F":5,"F#":6,"GB":6,
83
- "G":7,"G#":8,"AB":8,"A":9,"A#":10,"BB":10,"B":11}
 
 
 
 
 
 
 
 
84
 
85
  INSTRUMENT_PROGRAMS: Dict[str, int] = {
86
- "Piano": 0, "Trumpet": 56, "Violin": 40,
87
- "Clarinet": 71, "Flute": 73,
88
  }
89
 
90
- def sanitise_note_name(n: str) -> str:
91
- """Convert common variations to canonical form."""
92
- n = n.strip().upper()
93
- n = re.sub(r'\bFLAT\b', 'B', n, flags=re.I)
94
- n = re.sub(r'\bSHARP\b', '#', n, flags=re.I)
95
- n = re.sub(r'([A-G])\s*-\s*FLAT', r'\1B', n, flags=re.I)
96
- n = re.sub(r'([A-G])\s*-\s*SHARP', r'\1#', n, flags=re.I)
97
- return n
98
 
99
  def note_name_to_midi(note: str) -> int:
100
- note = sanitise_note_name(note)
101
- match = re.fullmatch(r"([A-G][#B]?)(\d)", note)
 
 
 
102
  if not match:
103
  raise ValueError(f"Invalid note: {note}")
104
- pitch, octave = match.groups()
 
 
 
 
 
 
 
 
 
 
105
  if pitch not in NOTE_MAP:
106
  raise ValueError(f"Invalid pitch: {pitch}")
107
- return NOTE_MAP[pitch] + (int(octave) + 1) * 12
 
 
 
 
 
 
 
 
108
 
109
  # -----------------------------------------------------------------------------
110
- # 5. Integer duration scaling
111
  # -----------------------------------------------------------------------------
112
  def scale_json_durations(json_data, target_units: int) -> list:
113
- """Convert durations to integers and scale to match target exactly"""
114
- ints = [[n, max(1, int(round(d)))] for n, d in json_data]
115
- total = sum(d for _, d in ints)
116
- deficit = target_units - total
117
-
118
- if deficit > 0:
119
- for i in range(deficit):
120
- ints[i % len(ints)][1] += 1
121
- elif deficit < 0:
122
- for i in range(-deficit):
123
- if ints[i % len(ints)][1] > 1:
124
- ints[i % len(ints)][1] -= 1
125
- return ints
 
 
 
 
 
 
 
126
 
127
  # -----------------------------------------------------------------------------
128
- # 6. MIDI from scaled JSON (integer durations)
129
  # -----------------------------------------------------------------------------
130
- def json_to_midi(json_data: list, instrument: str, tempo: int,
131
- time_signature: str, measures: int) -> MidiFile:
132
  mid = MidiFile(ticks_per_beat=TICKS_PER_BEAT)
133
  track = MidiTrack(); mid.tracks.append(track)
134
  program = INSTRUMENT_PROGRAMS.get(instrument, 56)
135
- num, denom = map(int, time_signature.split('/'))
136
 
137
- track.append(MetaMessage('time_signature', numerator=num, denominator=denom, time=0))
 
 
 
138
  track.append(MetaMessage('set_tempo', tempo=mido.bpm2tempo(tempo), time=0))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
  track.append(Message('program_change', program=program, time=0))
140
 
141
- for note_name, dur_units in json_data:
142
- note_num = note_name_to_midi(note_name)
143
- ticks = max(int(dur_units * TICKS_PER_UNIT), 1)
144
- vel = random.randint(60, 100)
145
- track.append(Message('note_on', note=note_num, velocity=vel, time=0))
146
- track.append(Message('note_off', note=note_num, velocity=vel, time=ticks))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147
  return mid
148
 
149
  # -----------------------------------------------------------------------------
@@ -166,58 +271,127 @@ def midi_to_mp3(midi_obj: MidiFile, instrument: str = "Trumpet") -> Tuple[str, f
166
  wav_path = mid_file.name.replace(".mid", ".wav")
167
  mp3_path = mid_file.name.replace(".mid", ".mp3")
168
  sf2_path = get_soundfont(instrument)
169
-
170
  try:
171
  sp.run([
172
  'fluidsynth', '-ni', sf2_path, mid_file.name,
173
- '-F', wav_path, '-r', '44100'
174
  ], check=True, capture_output=True)
175
  except Exception:
176
- fs = FluidSynth(sf2_path, sample_rate=44100)
177
  fs.midi_to_audio(mid_file.name, wav_path)
178
-
179
- sound = AudioSegment.from_wav(wav_path)
180
- sound.export(mp3_path, format="mp3")
181
- static_mp3_path = os.path.join('static', os.path.basename(mp3_path))
182
- shutil.move(mp3_path, static_mp3_path)
183
-
184
- # Cleanup
185
- for f in [mid_file.name, wav_path]:
186
- try:
187
- os.remove(f)
188
- except FileNotFoundError:
189
- pass
190
-
191
- return static_mp3_path, sound.duration_seconds
 
 
192
 
193
  # -----------------------------------------------------------------------------
194
- # 8. Fallback patterns for error recovery
195
  # -----------------------------------------------------------------------------
196
  def get_fallback_exercise(instrument: str, level: str, key: str,
197
  time_sig: str, measures: int) -> str:
198
- patterns = {
199
- "Trumpet": ["C4", "D4", "E4", "G4"],
200
- "Piano": ["C4", "E4", "G4", "C5"],
201
- "Violin": ["G4", "A4", "B4", "D5"],
202
- "Clarinet": ["E4", "F4", "G4", "Bb4"],
203
- "Flute": ["A4", "B4", "C5", "E5"],
 
 
204
  }
205
- pat = patterns.get(instrument, patterns["Trumpet"])
206
- numerator = int(time_sig.split('/')[0])
207
- target = measures * numerator
208
- notes, durs = [], []
209
- i = 0
210
- while sum(durs) < target:
211
- notes.append(pat[i % len(pat)])
212
- durs.append(1)
213
- i += 1
214
- if sum(durs) > target:
215
- durs[-1] -= sum(durs) - target
216
- return json.dumps([[n,d] for n,d in zip(notes,durs)])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
217
 
218
- # -----------------------------------------------------------------------------
219
- # 9. Difficulty level helpers (FIXED)
220
- # -----------------------------------------------------------------------------
221
  def get_style_based_on_level(level: str) -> str:
222
  styles = {
223
  "Beginner": ["simple", "legato", "stepwise"],
@@ -235,7 +409,7 @@ def get_technique_based_on_level(level: str) -> str:
235
  return random.choice(techniques.get(level, ["with slurs"]))
236
 
237
  # -----------------------------------------------------------------------------
238
- # 10. Mistral API with strengthened integer duration prompts (FIXED)
239
  # -----------------------------------------------------------------------------
240
  def query_mistral(prompt: str, instrument: str, level: str, key: str,
241
  time_sig: str, measures: int) -> str:
@@ -243,117 +417,213 @@ def query_mistral(prompt: str, instrument: str, level: str, key: str,
243
  "Authorization": f"Bearer {MISTRAL_API_KEY}",
244
  "Content-Type": "application/json",
245
  }
246
- numerator = int(time_sig.split('/')[0])
247
- target_units = measures * numerator
248
 
249
- strict_format = (
250
- "Use ONLY standard note names like 'C4', 'F#5', 'Bb3'. "
251
- "Use ONLY integer durations representing quarter-note beats: "
252
- "1 = quarter, 2 = half, 4 = whole. "
253
- f"Sum MUST equal exactly {target_units}. "
254
- "Output ONLY a JSON array of [note, integer_duration] pairs. "
255
- "No prose, no explanation."
 
 
256
  )
257
-
 
 
 
 
258
  if prompt.strip():
259
- user_prompt = f"{prompt}\n\n{strict_format}"
 
 
 
260
  else:
261
- # FIXED: Incorporate difficulty level
262
  style = get_style_based_on_level(level)
263
- tech = get_technique_based_on_level(level)
 
 
 
 
 
 
 
 
 
 
264
  user_prompt = (
265
- f"Create a {style} {instrument.lower()} exercise in {key}, {time_sig}, {tech}. "
266
- f"{strict_format}"
 
 
 
 
 
267
  )
268
-
269
  payload = {
270
  "model": "mistral-medium",
271
  "messages": [
272
- {"role": "system", "content": f"You are an expert {instrument.lower()} teacher."},
273
- {"role": "user", "content": user_prompt}
274
  ],
275
- "temperature": 0.6,
276
- "max_tokens": 800
 
 
 
277
  }
278
-
279
  try:
280
  response = requests.post(MISTRAL_API_URL, headers=headers, json=payload)
281
  response.raise_for_status()
282
- return response.json()["choices"][0]["message"]["content"].strip()
 
283
  except Exception as e:
284
  print(f"Error querying Mistral API: {e}")
285
  return get_fallback_exercise(instrument, level, key, time_sig, measures)
286
 
287
  # -----------------------------------------------------------------------------
288
- # 11. Robust JSON parsing
289
  # -----------------------------------------------------------------------------
290
  def safe_parse_json(text: str) -> Optional[list]:
291
  try:
292
- text = text.replace("'", '"')
293
- match = re.search(r"\[(\s*\[.*?\]\s*,?)*\]", text, re.DOTALL)
294
- return json.loads(match.group(0) if match else json.loads(text))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
295
  except Exception as e:
296
  print(f"JSON parsing error: {e}\nRaw text: {text}")
297
  return None
298
 
299
  # -----------------------------------------------------------------------------
300
- # 12. Main generation workflow
301
  # -----------------------------------------------------------------------------
302
- def generate_exercise(instrument: str, level: str, key: str, tempo: int,
303
- time_signature: str, measures: int, custom_prompt: str,
304
- mode: str) -> Tuple[str, Optional[str], str, MidiFile, str, str, int]:
305
  try:
306
  prompt_to_use = custom_prompt if mode == "Exercise Prompt" else ""
307
- raw = query_mistral(prompt_to_use, instrument, level, key, time_signature, measures)
308
- parsed = safe_parse_json(raw)
309
-
310
  if not parsed:
311
- return "Invalid JSON format", None, str(tempo), None, "0", time_signature, 0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
312
 
313
- # Scale to exact integer durations
314
- numerator = int(time_signature.split('/')[0])
315
- target_units = measures * numerator
316
- scaled = scale_json_durations(parsed, target_units)
317
-
318
  # Calculate total duration units
319
- total_duration = sum(d for _, d in scaled)
320
-
321
  # Generate MIDI and audio
322
- midi = json_to_midi(scaled, instrument, tempo, time_signature, measures)
323
  mp3_path, real_duration = midi_to_mp3(midi, instrument)
324
-
325
- return (json.dumps(scaled, indent=2), mp3_path, str(tempo),
326
- midi, f"{real_duration:.2f} seconds", time_signature, total_duration)
327
  except Exception as e:
328
  return f"Error: {str(e)}", None, str(tempo), None, "0", time_signature, 0
329
 
330
  # -----------------------------------------------------------------------------
331
- # 13. AI chat assistant
332
  # -----------------------------------------------------------------------------
333
  def handle_chat(message: str, history: List, instrument: str, level: str):
334
  if not message.strip():
335
  return "", history
336
-
337
- messages = [{"role": "system",
338
- "content": f"You are a {instrument} teacher for {level} students."}]
339
-
340
  for user_msg, assistant_msg in history:
341
  messages.append({"role": "user", "content": user_msg})
342
  messages.append({"role": "assistant", "content": assistant_msg})
343
-
344
  messages.append({"role": "user", "content": message})
345
-
346
- headers = {
347
- "Authorization": f"Bearer {MISTRAL_API_KEY}",
348
- "Content-Type": "application/json"
349
- }
350
- payload = {
351
- "model": "mistral-medium",
352
- "messages": messages,
353
- "temperature": 0.7,
354
- "max_tokens": 500
355
- }
356
-
357
  try:
358
  response = requests.post(MISTRAL_API_URL, headers=headers, json=payload)
359
  response.raise_for_status()
@@ -365,151 +635,545 @@ def handle_chat(message: str, history: List, instrument: str, level: str):
365
  return "", history
366
 
367
  # -----------------------------------------------------------------------------
368
- # 14. Gradio user interface
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
369
  # -----------------------------------------------------------------------------
370
  def create_ui() -> gr.Blocks:
371
  with gr.Blocks(title="Adaptive Music Exercise Generator", theme="soft") as demo:
372
- gr.Markdown("# 🎼 Adaptive Music Exercise Generator (Integer Durations)")
373
  current_midi = gr.State(None)
374
  current_exercise = gr.State("")
 
375
 
376
- mode = gr.Radio(["Exercise Parameters","Exercise Prompt"],
377
- value="Exercise Parameters", label="Generation Mode")
378
-
379
  with gr.Row():
380
  with gr.Column(scale=1):
381
  with gr.Group(visible=True) as params_group:
382
  gr.Markdown("### Exercise Parameters")
383
- instrument = gr.Dropdown(
384
- ["Trumpet", "Piano", "Violin", "Clarinet", "Flute"],
385
- value="Trumpet", label="Instrument"
386
- )
387
- level = gr.Radio(
388
- ["Beginner", "Intermediate", "Advanced"],
389
- value="Intermediate", label="Difficulty Level"
390
- )
391
- key = gr.Dropdown(
392
- ["C Major", "G Major", "D Major", "F Major",
393
- "Bb Major", "A Minor", "E Minor"],
394
- value="C Major", label="Key Signature"
395
- )
396
- time_signature = gr.Dropdown(
397
- ["3/4", "4/4"], value="4/4", label="Time Signature"
398
- )
399
- measures = gr.Radio(
400
- [4, 8], value=4, label="Length (measures)"
401
- )
402
-
403
  with gr.Group(visible=False) as prompt_group:
404
  gr.Markdown("### Exercise Prompt")
405
- custom_prompt = gr.Textbox(
406
- "",
407
- label="Describe your exercise (e.g. 'Jazz trumpet exercise with syncopation')",
408
- lines=3
409
- )
410
- measures_prompt = gr.Radio(
411
- [4, 8], value=4, label="Length (measures)"
412
- )
413
-
414
  generate_btn = gr.Button("Generate Exercise", variant="primary")
415
-
416
  with gr.Column(scale=2):
417
  with gr.Tabs():
418
  with gr.TabItem("Exercise Player"):
419
- audio_output = gr.Audio(label="Generated Exercise",
420
- autoplay=True, type="filepath")
421
- bpm_display = gr.Textbox(label="Tempo (BPM)")
422
- time_sig_display = gr.Textbox(label="Time Signature")
423
- duration_display = gr.Textbox(label="Audio Duration",
424
- interactive=False)
425
-
 
 
 
 
 
 
 
 
 
426
  with gr.TabItem("Exercise Data"):
427
- json_output = gr.Code(
428
- label="JSON Representation",
429
- language="json",
430
- interactive=True
431
- )
432
- # NEW: Duration sum display
433
  duration_sum = gr.Number(
434
- label="Total Duration Units (1 unit = quarter note)",
435
  interactive=False,
436
  precision=0
437
  )
438
-
 
 
 
 
439
  with gr.TabItem("MIDI Export"):
440
  midi_output = gr.File(label="MIDI File")
441
  download_midi = gr.Button("Generate MIDI File")
442
-
 
 
 
 
 
443
  with gr.TabItem("AI Chat"):
444
- chat_history = gr.Chatbot(label="Practice Assistant",
445
- height=400)
446
- chat_message = gr.Textbox(
447
- label="Ask about technique, theory, or practice strategies"
448
- )
449
  send_chat_btn = gr.Button("Send")
450
-
451
- # UI visibility toggling
452
  mode.change(
453
  fn=lambda m: {
454
  params_group: gr.update(visible=(m == "Exercise Parameters")),
455
  prompt_group: gr.update(visible=(m == "Exercise Prompt")),
456
  },
457
- inputs=[mode],
458
- outputs=[params_group, prompt_group]
459
  )
460
-
461
- # Generate exercise handler
462
  def generate_caller(mode_val, instrument_val, level_val, key_val,
463
- time_sig_val, measures_val, prompt_val, measures_prompt_val):
464
  real_measures = measures_prompt_val if mode_val == "Exercise Prompt" else measures_val
465
- fixed_tempo = 60 # Fixed tempo for simplicity
466
- return generate_exercise(
467
- instrument_val, level_val, key_val, fixed_tempo,
468
- time_sig_val, real_measures, prompt_val, mode_val
469
  )
470
-
 
 
 
 
 
 
 
 
471
  generate_btn.click(
472
  fn=generate_caller,
473
- inputs=[mode, instrument, level, key, time_signature,
474
- measures, custom_prompt, measures_prompt],
475
- outputs=[json_output, audio_output, bpm_display,
476
- current_midi, duration_display, time_sig_display, duration_sum]
477
  )
478
 
479
- # MIDI export handler
480
- def save_midi(json_data, instr, time_sig):
481
- parsed = safe_parse_json(json_data)
482
- if not parsed:
483
- return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
484
 
485
- numerator = int(time_sig.split('/')[0])
486
- target_units = sum(int(d) for _, d in parsed)
487
- measures_est = max(1, round(target_units / numerator))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
488
 
489
- scaled = scale_json_durations(parsed, measures_est * numerator)
490
- midi_obj = json_to_midi(scaled, instr, 60, time_sig, measures_est)
491
 
492
- midi_path = os.path.join("static", "exercise.mid")
493
- midi_obj.save(midi_path)
494
- return midi_path
 
 
 
495
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
496
  download_midi.click(
497
  fn=save_midi,
498
- inputs=[json_output, instrument, time_signature],
499
- outputs=[midi_output]
500
  )
501
-
502
- # Chat handler
503
  send_chat_btn.click(
504
  fn=handle_chat,
505
  inputs=[chat_message, chat_history, instrument, level],
506
- outputs=[chat_message, chat_history]
507
  )
508
-
509
  return demo
510
 
511
  # -----------------------------------------------------------------------------
512
- # 15. Entry point
513
  # -----------------------------------------------------------------------------
514
  if __name__ == "__main__":
515
  demo = create_ui()
 
1
  """
2
+ Adaptive Music Exercise Generator (Strict Duration Enforcement)
3
+ ==============================================================
4
+ Generates custom musical exercises with LLM, perfectly fit to user-specified number of measures
5
+ AND time signature, guaranteeing exact durations in MIDI and in the UI!
 
 
6
  Major updates:
7
+ - Changed base duration unit from 16th notes to 8th notes (1 unit = 8th note)
8
+ - Updated all calculations and prompts to use new duration system
9
+ - Duration sum display now shows total in 8th notes
10
+ - Maintained all original functionality
11
+ - Added cumulative duration tracking
12
+ - Enforced JSON output format with note, duration, cumulative_duration
13
+ - Enhanced rest handling and JSON parsing
14
+ - Fixed JSON parsing errors for 8-measure exercises
15
+ - Added robust error handling for MIDI generation
16
  """
17
 
18
  # -----------------------------------------------------------------------------
19
+ # 1. Runtime-time package installation (for fresh containers/Colab/etc)
20
  # -----------------------------------------------------------------------------
21
  import sys
22
  import subprocess
 
32
 
33
  install([
34
  "mido", "midi2audio", "pydub", "gradio",
35
+ "requests", "numpy", "matplotlib", "librosa", "scipy",
36
+ "uuid", "datetime"
37
  ])
38
 
39
  # -----------------------------------------------------------------------------
 
46
  import mido
47
  from mido import Message, MidiFile, MidiTrack, MetaMessage
48
  import re
49
+ from io import BytesIO
 
 
50
  from midi2audio import FluidSynth
51
  from pydub import AudioSegment
52
  import gradio as gr
53
+ import numpy as np
54
+ import matplotlib.pyplot as plt
55
+ import librosa
56
+ from scipy.io import wavfile
57
+ import os
58
+ import subprocess as sp
59
+ import base64
60
+ import shutil
61
+ import ast
62
+ import uuid
63
+ from datetime import datetime
64
+ import time
65
 
66
  # -----------------------------------------------------------------------------
67
+ # 3. Configuration & constants (UPDATED TO USE 8TH NOTES)
68
  # -----------------------------------------------------------------------------
69
  MISTRAL_API_URL = "https://api.mistral.ai/v1/chat/completions"
70
+ MISTRAL_API_KEY = "yQdfM8MLbX9uhInQ7id4iUTwN4h4pDLX" # ← Replace with your key!
71
 
72
  SOUNDFONT_URLS = {
73
  "Trumpet": "https://github.com/FluidSynth/fluidsynth/raw/master/sf2/Trumpet.sf2",
 
77
  "Flute": "https://musical-artifacts.com/artifacts/2744/SalC5Light.sf2",
78
  }
79
 
80
+ SAMPLE_RATE = 44100 # Hz
81
+ TICKS_PER_BEAT = 480 # Standard MIDI resolution
82
+ TICKS_PER_8TH = TICKS_PER_BEAT // 2 # 240 ticks per 8th note (UPDATED)
83
 
84
  if not os.path.exists('/usr/bin/fluidsynth'):
85
  try:
 
88
  print("Could not install FluidSynth automatically. Please install it manually.")
89
 
90
  os.makedirs("static", exist_ok=True)
91
+ os.makedirs("temp_audio", exist_ok=True)
92
+ os.makedirs("saved_exercises", exist_ok=True)
93
 
94
  # -----------------------------------------------------------------------------
95
+ # 4. Music theory helpers (note names β†”οΈŽ MIDI numbers) - ENHANCED REST HANDLING
96
  # -----------------------------------------------------------------------------
97
+ NOTE_MAP: Dict[str, int] = {
98
+ "C": 0, "C#": 1, "DB": 1,
99
+ "D": 2, "D#": 3, "EB": 3,
100
+ "E": 4, "F": 5, "F#": 6, "GB": 6,
101
+ "G": 7, "G#": 8, "AB": 8,
102
+ "A": 9, "A#": 10, "BB": 10,
103
+ "B": 11,
104
+ }
105
+
106
+ REST_INDICATORS = ["rest", "r", "Rest", "R", "P", "p", "pause"]
107
 
108
  INSTRUMENT_PROGRAMS: Dict[str, int] = {
109
+ "Piano": 0, "Trumpet": 56, "Violin": 40,
110
+ "Clarinet": 71, "Flute": 73,
111
  }
112
 
113
+ def is_rest(note: str) -> bool:
114
+ """Check if a note string represents a rest."""
115
+ return note.strip().lower() in [r.lower() for r in REST_INDICATORS]
 
 
 
 
 
116
 
117
  def note_name_to_midi(note: str) -> int:
118
+ if is_rest(note):
119
+ return -1 # Special value for rests
120
+
121
+ # Allow both scientific (C4) and Helmholtz (C') notation
122
+ match = re.match(r"([A-Ga-g][#b]?)(\'*)(\d?)", note)
123
  if not match:
124
  raise ValueError(f"Invalid note: {note}")
125
+
126
+ pitch, apostrophes, octave = match.groups()
127
+ pitch = pitch.upper().replace('b', 'B')
128
+
129
+ # Handle Helmholtz notation (C' = C5, C'' = C6, etc)
130
+ octave_num = 4
131
+ if octave:
132
+ octave_num = int(octave)
133
+ elif apostrophes:
134
+ octave_num = 5 + len(apostrophes)
135
+
136
  if pitch not in NOTE_MAP:
137
  raise ValueError(f"Invalid pitch: {pitch}")
138
+
139
+ return NOTE_MAP[pitch] + (octave_num + 1) * 12
140
+
141
+ def midi_to_note_name(midi_num: int) -> str:
142
+ if midi_num == -1:
143
+ return "Rest"
144
+ notes = ["C", "C#", "D", "D#", "E", "F", "F#", "G", "G#", "A", "A#", "B"]
145
+ octave = (midi_num // 12) - 1
146
+ return f"{notes[midi_num % 12]}{octave}"
147
 
148
  # -----------------------------------------------------------------------------
149
+ # 5. Duration scaling: guarantee the output sums to requested total (using integers)
150
  # -----------------------------------------------------------------------------
151
  def scale_json_durations(json_data, target_units: int) -> list:
152
+ """Scales durations so that their sum is exactly target_units (8th notes)."""
153
+ durations = [int(d) for _, d in json_data]
154
+ total = sum(durations)
155
+ if total == 0:
156
+ return json_data
157
+
158
+ # Calculate proportional scaling with integer arithmetic
159
+ scaled = []
160
+ remainder = target_units
161
+ for i, (note, d) in enumerate(json_data):
162
+ if i < len(json_data) - 1:
163
+ # Proportional allocation
164
+ portion = max(1, round(d * target_units / total))
165
+ scaled.append([note, portion])
166
+ remainder -= portion
167
+ else:
168
+ # Last note gets all remaining units
169
+ scaled.append([note, max(1, remainder)])
170
+
171
+ return scaled
172
 
173
  # -----------------------------------------------------------------------------
174
+ # 6. MIDI from scaled JSON (using integer durations) - UPDATED REST HANDLING
175
  # -----------------------------------------------------------------------------
176
+ def json_to_midi(json_data: list, instrument: str, tempo: int, time_signature: str, measures: int, key: str = "C Major") -> MidiFile:
 
177
  mid = MidiFile(ticks_per_beat=TICKS_PER_BEAT)
178
  track = MidiTrack(); mid.tracks.append(track)
179
  program = INSTRUMENT_PROGRAMS.get(instrument, 56)
180
+ numerator, denominator = map(int, time_signature.split('/'))
181
 
182
+ # Add time signature meta message
183
+ track.append(MetaMessage('time_signature', numerator=numerator,
184
+ denominator=denominator, time=0))
185
+ # Add tempo meta message
186
  track.append(MetaMessage('set_tempo', tempo=mido.bpm2tempo(tempo), time=0))
187
+
188
+ # Add key signature meta message based on the key
189
+ # For MIDI key signatures, the key parameter expects a string like 'C', 'F#m', etc.
190
+ key_map = {
191
+ "C Major": "C",
192
+ "G Major": "G",
193
+ "D Major": "D",
194
+ "F Major": "F",
195
+ "Bb Major": "Bb",
196
+ "A Minor": "Am",
197
+ "E Minor": "Em",
198
+ }
199
+
200
+ # Use the provided key or default to C major if key not found
201
+ midi_key = key_map.get(key, "C")
202
+ # The 'key' parameter in MetaMessage expects a string like 'C', 'F#m', etc.
203
+ track.append(MetaMessage('key_signature', key=midi_key, time=0))
204
+
205
+ # Set instrument program
206
  track.append(Message('program_change', program=program, time=0))
207
 
208
+ # Accumulator for rest durations
209
+ accumulated_rest = 0
210
+
211
+ for note_item in json_data:
212
+ try:
213
+ # Handle both formats: [note, duration] and {note, duration}
214
+ if isinstance(note_item, list) and len(note_item) == 2:
215
+ note_name, duration_units = note_item
216
+ elif isinstance(note_item, dict):
217
+ note_name = note_item["note"]
218
+ duration_units = note_item["duration"]
219
+ else:
220
+ print(f"Unsupported note format: {note_item}")
221
+ continue
222
+
223
+ ticks = int(duration_units * TICKS_PER_8TH)
224
+ ticks = max(ticks, 1)
225
+
226
+ if is_rest(note_name):
227
+ # Accumulate rest duration
228
+ accumulated_rest += ticks
229
+ else:
230
+ # Process any accumulated rest first
231
+ if accumulated_rest > 0:
232
+ # Add rest by creating a silent note (velocity 0) that won't be heard
233
+ # Or just skip and use accumulated_rest in timing
234
+ # We'll just add the time to the next note
235
+ track.append(Message('note_on', note=0, velocity=0, time=accumulated_rest))
236
+ track.append(Message('note_off', note=0, velocity=0, time=0))
237
+ accumulated_rest = 0
238
+
239
+ # Process actual note
240
+ note_num = note_name_to_midi(note_name)
241
+ velocity = random.randint(60, 100)
242
+ track.append(Message('note_on', note=note_num, velocity=velocity, time=0))
243
+ track.append(Message('note_off', note=note_num, velocity=velocity, time=ticks))
244
+ except Exception as e:
245
+ print(f"Error parsing note {note_item}: {e}")
246
+
247
+ # Handle trailing rest
248
+ if accumulated_rest > 0:
249
+ track.append(Message('note_on', note=0, velocity=0, time=accumulated_rest))
250
+ track.append(Message('note_off', note=0, velocity=0, time=0))
251
+
252
  return mid
253
 
254
  # -----------------------------------------------------------------------------
 
271
  wav_path = mid_file.name.replace(".mid", ".wav")
272
  mp3_path = mid_file.name.replace(".mid", ".mp3")
273
  sf2_path = get_soundfont(instrument)
 
274
  try:
275
  sp.run([
276
  'fluidsynth', '-ni', sf2_path, mid_file.name,
277
+ '-F', wav_path, '-r', '44100', '-g', '1.0'
278
  ], check=True, capture_output=True)
279
  except Exception:
280
+ fs = FluidSynth(sf2_path, sample_rate=44100, gain=1.0)
281
  fs.midi_to_audio(mid_file.name, wav_path)
282
+ try:
283
+ sound = AudioSegment.from_wav(wav_path)
284
+ if instrument == "Trumpet":
285
+ sound = sound.high_pass_filter(200)
286
+ elif instrument == "Violin":
287
+ sound = sound.low_pass_filter(5000)
288
+ sound.export(mp3_path, format="mp3")
289
+ static_mp3_path = os.path.join('static', os.path.basename(mp3_path))
290
+ shutil.move(mp3_path, static_mp3_path)
291
+ return static_mp3_path, sound.duration_seconds
292
+ finally:
293
+ for f in [mid_file.name, wav_path]:
294
+ try:
295
+ os.remove(f)
296
+ except FileNotFoundError:
297
+ pass
298
 
299
  # -----------------------------------------------------------------------------
300
+ # 8. Prompt engineering for variety (using integer durations) - UPDATED DURATION SYSTEM
301
  # -----------------------------------------------------------------------------
302
  def get_fallback_exercise(instrument: str, level: str, key: str,
303
  time_sig: str, measures: int) -> str:
304
+ key_notes = {
305
+ "C Major": ["C4", "D4", "E4", "F4", "G4", "A4", "B4"],
306
+ "G Major": ["G3", "A3", "B3", "C4", "D4", "E4", "F#4"],
307
+ "D Major": ["D4", "E4", "F#4", "G4", "A4", "B4", "C#5"],
308
+ "F Major": ["F3", "G3", "A3", "Bb3", "C4", "D4", "E4"],
309
+ "Bb Major": ["Bb3", "C4", "D4", "Eb4", "F4", "G4", "A4"],
310
+ "A Minor": ["A3", "B3", "C4", "D4", "E4", "F4", "G4"],
311
+ "E Minor": ["E3", "F#3", "G3", "A3", "B3", "C4", "D4"],
312
  }
313
+
314
+ # Get fundamental note from key signature
315
+ fundamental_note = key.split()[0] # Gets 'C' from 'C Major' or 'A' from 'A Minor'
316
+ is_major = "Major" in key
317
+
318
+ # Get notes for the key
319
+ notes = key_notes.get(key, key_notes["C Major"])
320
+
321
+ # Find fundamental note with octave for ending
322
+ fundamental_with_octave = None
323
+ for note in notes:
324
+ if note.startswith(fundamental_note):
325
+ fundamental_with_octave = note
326
+ break
327
+
328
+ # If not found, use the first note (should not happen with our key definitions)
329
+ if not fundamental_with_octave:
330
+ fundamental_with_octave = notes[0]
331
+
332
+ numerator, denominator = map(int, time_sig.split('/'))
333
+
334
+ # Calculate units based on 8th notes
335
+ units_per_measure = numerator * (8 // denominator)
336
+ target_units = measures * units_per_measure
337
+
338
+ # Create a rhythm pattern based on time signature
339
+ if numerator == 3:
340
+ rhythm = [2, 1, 2, 1, 2] # 3/4 pattern
341
+ else:
342
+ rhythm = [2, 2, 1, 1, 2, 2] # 4/4 pattern
343
+
344
+ # Build exercise
345
+ result = []
346
+ cumulative = 0
347
+ current_units = 0
348
+
349
+ # Reserve at least 2 units for the final note
350
+ final_note_duration = min(4, max(2, rhythm[0])) # Between 2 and 4 units
351
+ available_units = target_units - final_note_duration
352
+
353
+ # Generate notes until we reach the available units
354
+ while current_units < available_units:
355
+ # Avoid minor 7th in major keys
356
+ if is_major:
357
+ # Filter out minor 7th notes (e.g., Bb in C major)
358
+ available_notes = [n for n in notes if not (n.startswith("Bb") and key == "C Major") and
359
+ not (n.startswith("F") and key == "G Major") and
360
+ not (n.startswith("C") and key == "D Major") and
361
+ not (n.startswith("Eb") and key == "F Major") and
362
+ not (n.startswith("Ab") and key == "Bb Major")]
363
+ else:
364
+ available_notes = notes
365
+
366
+ note = random.choice(available_notes)
367
+ dur = random.choice(rhythm)
368
+
369
+ # Don't exceed available units
370
+ if current_units + dur > available_units:
371
+ dur = available_units - current_units
372
+ if dur <= 0:
373
+ break
374
+
375
+ cumulative += dur
376
+ current_units += dur
377
+ result.append({
378
+ "note": note,
379
+ "duration": dur,
380
+ "cumulative_duration": cumulative
381
+ })
382
+
383
+ # Add the final note (fundamental of the key)
384
+ final_duration = target_units - current_units
385
+ if final_duration > 0:
386
+ cumulative += final_duration
387
+ result.append({
388
+ "note": fundamental_with_octave,
389
+ "duration": final_duration,
390
+ "cumulative_duration": cumulative
391
+ })
392
+
393
+ return json.dumps(result)
394
 
 
 
 
395
  def get_style_based_on_level(level: str) -> str:
396
  styles = {
397
  "Beginner": ["simple", "legato", "stepwise"],
 
409
  return random.choice(techniques.get(level, ["with slurs"]))
410
 
411
  # -----------------------------------------------------------------------------
412
+ # 9. Mistral API: query, fallback on errors - UPDATED DURATION SYSTEM
413
  # -----------------------------------------------------------------------------
414
  def query_mistral(prompt: str, instrument: str, level: str, key: str,
415
  time_sig: str, measures: int) -> str:
 
417
  "Authorization": f"Bearer {MISTRAL_API_KEY}",
418
  "Content-Type": "application/json",
419
  }
420
+ numerator, denominator = map(int, time_sig.split('/'))
 
421
 
422
+ # UPDATED: Calculate total required 8th notes
423
+ units_per_measure = numerator * (8 // denominator)
424
+ required_total = measures * units_per_measure
425
+
426
+ # UPDATED: Duration explanation in prompt
427
+ duration_constraint = (
428
+ f"Sum of all durations MUST BE EXACTLY {required_total} units (8th notes). "
429
+ f"Each integer duration represents an 8th note (1=8th, 2=quarter, 4=half, 8=whole). "
430
+ f"If it doesn't match, the exercise is invalid."
431
  )
432
+ system_prompt = (
433
+ f"You are an expert music teacher specializing in {instrument.lower()}. "
434
+ "Create customized exercises using INTEGER durations representing 8th notes."
435
+ )
436
+
437
  if prompt.strip():
438
+ user_prompt = (
439
+ f"{prompt} {duration_constraint} Output ONLY a JSON array of objects with "
440
+ "the following structure: [{{'note': string, 'duration': integer, 'cumulative_duration': integer}}]"
441
+ )
442
  else:
 
443
  style = get_style_based_on_level(level)
444
+ technique = get_technique_based_on_level(level)
445
+ # Extract fundamental note from key signature
446
+ fundamental_note = key.split()[0] # Gets 'C' from 'C Major' or 'A' from 'A Minor'
447
+ is_major = "Major" in key
448
+
449
+ # Create additional musical constraints
450
+ key_constraints = (
451
+ f"The exercise MUST end on the fundamental note of the key ({fundamental_note}). "
452
+ f"{'' if not is_major else 'For this major key, avoid using the minor 7th degree.'}"
453
+ )
454
+
455
  user_prompt = (
456
+ f"Create a {style} {instrument.lower()} exercise in {key} with {time_sig} time signature "
457
+ f"{technique} for a {level.lower()} player. {duration_constraint} {key_constraints} "
458
+ "Output ONLY a JSON array of objects with the following structure: "
459
+ "[{{'note': string, 'duration': integer, 'cumulative_duration': integer}}] "
460
+ "Use standard note names (e.g., \"Bb4\", \"F#5\"). Monophonic only. "
461
+ "Durations: 1=8th, 2=quarter, 4=half, 8=whole. "
462
+ "Sum must be exactly as specified. ONLY output the JSON array. No prose."
463
  )
464
+
465
  payload = {
466
  "model": "mistral-medium",
467
  "messages": [
468
+ {"role": "system", "content": system_prompt},
469
+ {"role": "user", "content": user_prompt},
470
  ],
471
+ "temperature": 0.7 if level == "Advanced" else 0.5,
472
+ "max_tokens": 1000,
473
+ "top_p": 0.95,
474
+ "frequency_penalty": 0.2,
475
+ "presence_penalty": 0.2,
476
  }
477
+
478
  try:
479
  response = requests.post(MISTRAL_API_URL, headers=headers, json=payload)
480
  response.raise_for_status()
481
+ content = response.json()["choices"][0]["message"]["content"]
482
+ return content.replace("```json","").replace("```","").strip()
483
  except Exception as e:
484
  print(f"Error querying Mistral API: {e}")
485
  return get_fallback_exercise(instrument, level, key, time_sig, measures)
486
 
487
  # -----------------------------------------------------------------------------
488
+ # 10. Robust JSON parsing for LLM outputs - ENHANCED PARSING
489
  # -----------------------------------------------------------------------------
490
  def safe_parse_json(text: str) -> Optional[list]:
491
  try:
492
+ text = text.strip().replace("'", '"')
493
+
494
+ # Find JSON array in the text
495
+ start_idx = text.find('[')
496
+ end_idx = text.rfind(']')
497
+ if start_idx == -1 or end_idx == -1:
498
+ return None
499
+
500
+ json_str = text[start_idx:end_idx+1]
501
+
502
+ # Fix common JSON issues
503
+ json_str = re.sub(r',\s*([}\]])', r'\1', json_str) # Trailing commas
504
+ json_str = re.sub(r'{\s*(\w+)\s*:', r'{"\1":', json_str) # Unquoted keys
505
+ json_str = re.sub(r':\s*([a-zA-Z_][a-zA-Z0-9_]*)(\s*[,}])', r':"\1"\2', json_str) # Unquoted strings
506
+
507
+ parsed = json.loads(json_str)
508
+
509
+ # Normalize keys to 'note' and 'duration'
510
+ normalized = []
511
+ for item in parsed:
512
+ if isinstance(item, dict):
513
+ # Find note value - accept multiple keys
514
+ note_val = None
515
+ for key in ['note', 'pitch', 'nota', 'ton']:
516
+ if key in item:
517
+ note_val = str(item[key])
518
+ break
519
+
520
+ # Find duration value
521
+ dur_val = None
522
+ for key in ['duration', 'dur', 'length', 'value']:
523
+ if key in item:
524
+ try:
525
+ dur_val = int(item[key])
526
+ except (TypeError, ValueError):
527
+ pass
528
+
529
+ if note_val is not None and dur_val is not None:
530
+ normalized.append({"note": note_val, "duration": dur_val})
531
+
532
+ return normalized if normalized else None
533
+
534
  except Exception as e:
535
  print(f"JSON parsing error: {e}\nRaw text: {text}")
536
  return None
537
 
538
  # -----------------------------------------------------------------------------
539
+ # 11. Main orchestration: talk to API, *scale durations*, build MIDI, UI values - UPDATED
540
  # -----------------------------------------------------------------------------
541
+ def generate_exercise(instrument: str, level: str, key: str, tempo: int, time_signature: str,
542
+ measures: int, custom_prompt: str, mode: str) -> Tuple[str, Optional[str], str, MidiFile, str, str, int]:
 
543
  try:
544
  prompt_to_use = custom_prompt if mode == "Exercise Prompt" else ""
545
+ output = query_mistral(prompt_to_use, instrument, level, key, time_signature, measures)
546
+ parsed = safe_parse_json(output)
 
547
  if not parsed:
548
+ print("Primary parsing failed, using fallback")
549
+ fallback_str = get_fallback_exercise(instrument, level, key, time_signature, measures)
550
+ parsed = safe_parse_json(fallback_str)
551
+ if not parsed:
552
+ print("Fallback parsing failed, using ultimate fallback")
553
+ # Ultimate fallback: simple scale based on selected key
554
+ key_notes = {
555
+ "C Major": ["C4", "D4", "E4", "F4", "G4", "A4", "B4", "C5"],
556
+ "G Major": ["G3", "A3", "B3", "C4", "D4", "E4", "F#4", "G4"],
557
+ "D Major": ["D4", "E4", "F#4", "G4", "A4", "B4", "C#5", "D5"],
558
+ "F Major": ["F3", "G3", "A3", "Bb3", "C4", "D4", "E4", "F4"],
559
+ "Bb Major": ["Bb3", "C4", "D4", "Eb4", "F4", "G4", "A4", "Bb4"],
560
+ "A Minor": ["A3", "B3", "C4", "D4", "E4", "F4", "G4", "A4"],
561
+ "E Minor": ["E3", "F#3", "G3", "A3", "B3", "C4", "D4", "E4"],
562
+ }
563
+ notes = key_notes.get(key, key_notes["C Major"])
564
+ numerator, denominator = map(int, time_signature.split('/'))
565
+ units_per_measure = numerator * (8 // denominator)
566
+ target_units = measures * units_per_measure
567
+ note_duration = max(1, target_units // len(notes))
568
+ parsed = [{"note": n, "duration": note_duration} for n in notes]
569
+ # Adjust last note to match total duration
570
+ total = sum(item["duration"] for item in parsed)
571
+ if total < target_units:
572
+ parsed[-1]["duration"] += target_units - total
573
+ elif total > target_units:
574
+ parsed[-1]["duration"] -= total - target_units
575
+
576
+ # Calculate total required 8th notes (UPDATED)
577
+ numerator, denominator = map(int, time_signature.split('/'))
578
+ units_per_measure = numerator * (8 // denominator)
579
+ total_units = measures * units_per_measure
580
+
581
+ # Convert to old format for scaling
582
+ old_format = []
583
+ for item in parsed:
584
+ if isinstance(item, dict):
585
+ old_format.append([item["note"], item["duration"]])
586
+ else:
587
+ old_format.append(item)
588
+
589
+ # Strict scaling
590
+ parsed_scaled_old = scale_json_durations(old_format, total_units)
591
+
592
+ # Convert back to new format with cumulative durations
593
+ cumulative = 0
594
+ parsed_scaled = []
595
+ for note, dur in parsed_scaled_old:
596
+ cumulative += dur
597
+ parsed_scaled.append({
598
+ "note": note,
599
+ "duration": dur,
600
+ "cumulative_duration": cumulative
601
+ })
602
 
 
 
 
 
 
603
  # Calculate total duration units
604
+ total_duration = cumulative
605
+
606
  # Generate MIDI and audio
607
+ midi = json_to_midi(parsed_scaled, instrument, tempo, time_signature, measures, key)
608
  mp3_path, real_duration = midi_to_mp3(midi, instrument)
609
+ output_json_str = json.dumps(parsed_scaled, indent=2)
610
+ return output_json_str, mp3_path, str(tempo), midi, f"{real_duration:.2f} seconds", time_signature, total_duration
 
611
  except Exception as e:
612
  return f"Error: {str(e)}", None, str(tempo), None, "0", time_signature, 0
613
 
614
  # -----------------------------------------------------------------------------
615
+ # 12. Simple AI chat assistant (optional, shares LLM)
616
  # -----------------------------------------------------------------------------
617
  def handle_chat(message: str, history: List, instrument: str, level: str):
618
  if not message.strip():
619
  return "", history
620
+ messages = [{"role": "system", "content": f"You are a {instrument} teacher for {level} students."}]
 
 
 
621
  for user_msg, assistant_msg in history:
622
  messages.append({"role": "user", "content": user_msg})
623
  messages.append({"role": "assistant", "content": assistant_msg})
 
624
  messages.append({"role": "user", "content": message})
625
+ headers = {"Authorization": f"Bearer {MISTRAL_API_KEY}", "Content-Type": "application/json"}
626
+ payload = {"model": "mistral-medium", "messages": messages, "temperature": 0.7, "max_tokens": 500}
 
 
 
 
 
 
 
 
 
 
627
  try:
628
  response = requests.post(MISTRAL_API_URL, headers=headers, json=payload)
629
  response.raise_for_status()
 
635
  return "", history
636
 
637
  # -----------------------------------------------------------------------------
638
+ # 13. New features: Visualization, Metronome, and Exercise Library
639
+ # -----------------------------------------------------------------------------
640
+
641
+ # Visualization function to create a piano roll representation of the exercise
642
+ def create_visualization(json_data, time_sig):
643
+ try:
644
+ if not json_data or "Error" in json_data:
645
+ return None
646
+
647
+ parsed = json.loads(json_data)
648
+ if not isinstance(parsed, list) or len(parsed) == 0:
649
+ return None
650
+
651
+ # Extract notes and durations
652
+ notes = []
653
+ durations = []
654
+ for item in parsed:
655
+ if isinstance(item, dict) and "note" in item and "duration" in item:
656
+ note_name = item["note"]
657
+ if not is_rest(note_name):
658
+ try:
659
+ midi_note = note_name_to_midi(note_name)
660
+ notes.append(midi_note)
661
+ durations.append(item["duration"])
662
+ except ValueError:
663
+ notes.append(60) # Default to middle C if parsing fails
664
+ durations.append(item["duration"])
665
+ else:
666
+ notes.append(None) # Represent rest
667
+ durations.append(item["duration"])
668
+
669
+ # Create piano roll visualization
670
+ fig, ax = plt.subplots(figsize=(12, 6))
671
+
672
+ # Calculate time positions
673
+ time_positions = [0]
674
+ for dur in durations[:-1]:
675
+ time_positions.append(time_positions[-1] + dur)
676
+
677
+ # Plot notes as rectangles
678
+ for i, (note, dur, pos) in enumerate(zip(notes, durations, time_positions)):
679
+ if note is not None: # Skip rests
680
+ rect = plt.Rectangle((pos, note-0.4), dur, 0.8, color='blue', alpha=0.7)
681
+ ax.add_patch(rect)
682
+ # Add note name
683
+ ax.text(pos + dur/2, note+0.5, midi_to_note_name(note),
684
+ ha='center', va='bottom', fontsize=8)
685
+
686
+ # Add measure lines
687
+ numerator, denominator = map(int, time_sig.split('/'))
688
+ units_per_measure = numerator * (8 // denominator)
689
+ max_time = time_positions[-1] + durations[-1]
690
+ for measure in range(1, int(max_time / units_per_measure) + 1):
691
+ measure_pos = measure * units_per_measure
692
+ if measure_pos <= max_time:
693
+ ax.axvline(x=measure_pos, color='gray', linestyle='--', alpha=0.5)
694
+
695
+ # Set axis limits and labels
696
+ ax.set_ylim(min(notes) - 5 if None not in notes else 55,
697
+ max(notes) + 5 if None not in notes else 75)
698
+ ax.set_xlim(0, max_time)
699
+ ax.set_ylabel('MIDI Note')
700
+ ax.set_xlabel('Time (8th note units)')
701
+ ax.set_title('Exercise Visualization')
702
+
703
+ # Add piano keyboard on y-axis
704
+ ax.set_yticks([60, 62, 64, 65, 67, 69, 71, 72]) # C4 to C5
705
+ ax.set_yticklabels(['C4', 'D4', 'E4', 'F4', 'G4', 'A4', 'B4', 'C5'])
706
+ ax.grid(True, axis='y', alpha=0.3)
707
+
708
+ # Save figure to temporary file
709
+ temp_img_path = os.path.join('static', f'visualization_{uuid.uuid4().hex}.png')
710
+ plt.tight_layout()
711
+ plt.savefig(temp_img_path)
712
+ plt.close()
713
+
714
+ return temp_img_path
715
+ except Exception as e:
716
+ print(f"Error creating visualization: {e}")
717
+ return None
718
+
719
+ # Metronome function
720
+ def create_metronome_audio(tempo, time_sig, measures):
721
+ try:
722
+ numerator, denominator = map(int, time_sig.split('/'))
723
+ # Create a MIDI file with metronome clicks
724
+ mid = MidiFile(ticks_per_beat=TICKS_PER_BEAT)
725
+ track = MidiTrack()
726
+ mid.tracks.append(track)
727
+
728
+ # Add time signature and tempo
729
+ track.append(MetaMessage('time_signature', numerator=numerator,
730
+ denominator=denominator, time=0))
731
+ track.append(MetaMessage('set_tempo', tempo=mido.bpm2tempo(int(tempo)), time=0))
732
+
733
+ # Calculate total beats
734
+ beats_per_measure = numerator
735
+ total_beats = beats_per_measure * measures
736
+
737
+ # Add metronome clicks (strong beat = note 77, weak beat = note 76)
738
+ for beat in range(total_beats):
739
+ # Strong beat on first beat of measure, weak beat otherwise
740
+ note_num = 77 if beat % beats_per_measure == 0 else 76
741
+ velocity = 100 if beat % beats_per_measure == 0 else 80
742
+
743
+ # Add note on (with time=0 for first beat)
744
+ if beat == 0:
745
+ track.append(Message('note_on', note=note_num, velocity=velocity, time=0))
746
+ else:
747
+ # Each beat is a quarter note (TICKS_PER_BEAT)
748
+ track.append(Message('note_on', note=note_num, velocity=velocity, time=TICKS_PER_BEAT))
749
+
750
+ # Short duration for click
751
+ track.append(Message('note_off', note=note_num, velocity=0, time=10))
752
+
753
+ # Save and convert to audio
754
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".mid") as mid_file:
755
+ mid.save(mid_file.name)
756
+ wav_path = mid_file.name.replace(".mid", ".wav")
757
+ mp3_path = mid_file.name.replace(".mid", ".mp3")
758
+
759
+ # Use piano soundfont for metronome
760
+ sf2_path = get_soundfont("Piano")
761
+ try:
762
+ sp.run([
763
+ 'fluidsynth', '-ni', sf2_path, mid_file.name,
764
+ '-F', wav_path, '-r', '44100', '-g', '1.0'
765
+ ], check=True, capture_output=True)
766
+ except Exception:
767
+ fs = FluidSynth(sf2_path, sample_rate=44100, gain=1.0)
768
+ fs.midi_to_audio(mid_file.name, wav_path)
769
+
770
+ # Convert to MP3
771
+ sound = AudioSegment.from_wav(wav_path)
772
+ sound.export(mp3_path, format="mp3")
773
+
774
+ # Move to static directory
775
+ static_mp3_path = os.path.join('static', f'metronome_{uuid.uuid4().hex}.mp3')
776
+ shutil.move(mp3_path, static_mp3_path)
777
+
778
+ # Clean up temporary files
779
+ for f in [mid_file.name, wav_path]:
780
+ try:
781
+ os.remove(f)
782
+ except FileNotFoundError:
783
+ pass
784
+
785
+ return static_mp3_path
786
+ except Exception as e:
787
+ print(f"Error creating metronome: {e}")
788
+ return None
789
+
790
+ # Function to save exercise to library
791
+ def save_exercise_to_library(json_data, instrument, level, key, time_sig, tempo, audio_path):
792
+ try:
793
+ if not json_data or "Error" in json_data or not audio_path:
794
+ return False, "Invalid exercise data"
795
+
796
+ # Create unique ID for exercise
797
+ exercise_id = uuid.uuid4().hex
798
+ timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
799
+
800
+ # Copy audio file to permanent storage
801
+ if audio_path and os.path.exists(audio_path):
802
+ exercise_audio_path = os.path.join("temp_audio", f"exercise_{exercise_id}.mp3")
803
+ shutil.copy(audio_path, exercise_audio_path)
804
+ else:
805
+ exercise_audio_path = ""
806
+
807
+ # Create exercise metadata
808
+ exercise_data = {
809
+ "id": exercise_id,
810
+ "timestamp": timestamp,
811
+ "instrument": instrument,
812
+ "level": level,
813
+ "key": key,
814
+ "time_signature": time_sig,
815
+ "tempo": tempo,
816
+ "json_data": json_data,
817
+ "audio_path": exercise_audio_path
818
+ }
819
+
820
+ # Save to library file
821
+ library_file = os.path.join("saved_exercises", "library.json")
822
+
823
+ # Load existing library or create new one
824
+ if os.path.exists(library_file):
825
+ try:
826
+ with open(library_file, "r") as f:
827
+ library = json.load(f)
828
+ except json.JSONDecodeError:
829
+ library = {"exercises": []}
830
+ else:
831
+ library = {"exercises": []}
832
+
833
+ # Add new exercise
834
+ library["exercises"].append(exercise_data)
835
+
836
+ # Save updated library
837
+ with open(library_file, "w") as f:
838
+ json.dump(library, f, indent=2)
839
+
840
+ return True, f"Exercise saved to library with ID: {exercise_id}"
841
+ except Exception as e:
842
+ return False, f"Error saving exercise: {str(e)}"
843
+
844
+ # Function to load exercises from library
845
+ def load_exercises_from_library():
846
+ try:
847
+ library_file = os.path.join("saved_exercises", "library.json")
848
+ if not os.path.exists(library_file):
849
+ return []
850
+
851
+ with open(library_file, "r") as f:
852
+ library = json.load(f)
853
+
854
+ return library.get("exercises", [])
855
+ except Exception as e:
856
+ print(f"Error loading library: {e}")
857
+ return []
858
+
859
+ # Function to calculate difficulty rating
860
+ def calculate_difficulty_rating(json_data, level):
861
+ try:
862
+ if not json_data or "Error" in json_data:
863
+ return 0
864
+
865
+ parsed = json.loads(json_data)
866
+ if not isinstance(parsed, list) or len(parsed) == 0:
867
+ return 0
868
+
869
+ # Extract notes and durations
870
+ notes = []
871
+ durations = []
872
+ for item in parsed:
873
+ if isinstance(item, dict) and "note" in item and "duration" in item:
874
+ note_name = item["note"]
875
+ if not is_rest(note_name):
876
+ try:
877
+ midi_note = note_name_to_midi(note_name)
878
+ notes.append(midi_note)
879
+ durations.append(item["duration"])
880
+ except ValueError:
881
+ pass
882
+
883
+ if not notes:
884
+ return 0
885
+
886
+ # Calculate difficulty factors
887
+ # 1. Range (wider range = harder)
888
+ note_range = max(notes) - min(notes) if notes else 0
889
+ range_factor = min(note_range / 12, 1.0) # Normalize to octave
890
+
891
+ # 2. Rhythmic complexity (more varied durations = harder)
892
+ unique_durations = len(set(durations))
893
+ rhythm_factor = min(unique_durations / 4, 1.0) # Normalize
894
+
895
+ # 3. Interval jumps (larger jumps = harder)
896
+ jumps = [abs(notes[i] - notes[i-1]) for i in range(1, len(notes))]
897
+ avg_jump = sum(jumps) / len(jumps) if jumps else 0
898
+ jump_factor = min(avg_jump / 7, 1.0) # Normalize to perfect fifth
899
+
900
+ # 4. Speed factor (shorter durations = harder)
901
+ avg_duration = sum(durations) / len(durations) if durations else 0
902
+ speed_factor = min(2.0 / avg_duration if avg_duration > 0 else 1.0, 1.0) # Normalize
903
+
904
+ # Calculate base difficulty
905
+ base_difficulty = (range_factor * 0.25 +
906
+ rhythm_factor * 0.25 +
907
+ jump_factor * 0.25 +
908
+ speed_factor * 0.25)
909
+
910
+ # Apply level multiplier
911
+ level_multiplier = {
912
+ "Beginner": 0.7,
913
+ "Intermediate": 1.0,
914
+ "Advanced": 1.3
915
+ }.get(level, 1.0)
916
+
917
+ # Calculate final rating (1-10 scale)
918
+ rating = round(base_difficulty * level_multiplier * 10)
919
+ return max(1, min(rating, 10)) # Ensure between 1-10
920
+ except Exception as e:
921
+ print(f"Error calculating difficulty: {e}")
922
+ return 0
923
+
924
+ # -----------------------------------------------------------------------------
925
+ # 14. Gradio user interface definition (for humans!) - ENHANCED GUI
926
  # -----------------------------------------------------------------------------
927
  def create_ui() -> gr.Blocks:
928
  with gr.Blocks(title="Adaptive Music Exercise Generator", theme="soft") as demo:
929
+ gr.Markdown("# 🎼 Adaptive Music Exercise Generator")
930
  current_midi = gr.State(None)
931
  current_exercise = gr.State("")
932
+ current_audio_path = gr.State(None)
933
 
934
+ mode = gr.Radio(["Exercise Parameters","Exercise Prompt"], value="Exercise Parameters", label="Generation Mode")
 
 
935
  with gr.Row():
936
  with gr.Column(scale=1):
937
  with gr.Group(visible=True) as params_group:
938
  gr.Markdown("### Exercise Parameters")
939
+ instrument = gr.Dropdown([
940
+ "Trumpet", "Piano", "Violin", "Clarinet", "Flute",
941
+ ], value="Trumpet", label="Instrument")
942
+ level = gr.Radio([
943
+ "Beginner", "Intermediate", "Advanced",
944
+ ], value="Intermediate", label="Difficulty Level")
945
+ key = gr.Dropdown([
946
+ "C Major", "G Major", "D Major", "F Major", "Bb Major", "A Minor", "E Minor",
947
+ ], value="C Major", label="Key Signature")
948
+ time_signature = gr.Dropdown(["3/4", "4/4"], value="4/4", label="Time Signature")
949
+ measures = gr.Radio([4, 8, 12, 16], value=4, label="Length (measures)")
 
 
 
 
 
 
 
 
 
950
  with gr.Group(visible=False) as prompt_group:
951
  gr.Markdown("### Exercise Prompt")
952
+ custom_prompt = gr.Textbox("", label="Enter your custom prompt", lines=3)
953
+ measures_prompt = gr.Radio([4, 8, 12, 16], value=4, label="Length (measures)")
 
 
 
 
 
 
 
954
  generate_btn = gr.Button("Generate Exercise", variant="primary")
 
955
  with gr.Column(scale=2):
956
  with gr.Tabs():
957
  with gr.TabItem("Exercise Player"):
958
+ audio_output = gr.Audio(label="Generated Exercise", autoplay=True, type="filepath")
959
+ with gr.Row():
960
+ bpm_display = gr.Textbox(label="Tempo (BPM)")
961
+ time_sig_display = gr.Textbox(label="Time Signature")
962
+ duration_display = gr.Textbox(label="Audio Duration", interactive=False)
963
+ with gr.Row():
964
+ difficulty_rating = gr.Number(label="Difficulty Rating (1-10)", interactive=False, precision=1)
965
+ save_btn = gr.Button("Save to Library", variant="secondary")
966
+
967
+ # Metronome section
968
+ gr.Markdown("### Metronome")
969
+ with gr.Row():
970
+ metronome_tempo = gr.Slider(minimum=40, maximum=200, value=60, step=1, label="Metronome Tempo")
971
+ metronome_btn = gr.Button("Generate Metronome", variant="secondary")
972
+ metronome_audio = gr.Audio(label="Metronome", type="filepath")
973
+
974
  with gr.TabItem("Exercise Data"):
975
+ json_output = gr.Code(label="JSON Representation", language="json")
 
 
 
 
 
976
  duration_sum = gr.Number(
977
+ label="Total Duration Units (8th notes)",
978
  interactive=False,
979
  precision=0
980
  )
981
+
982
+ with gr.TabItem("Visualization"):
983
+ visualization_output = gr.Image(label="Exercise Visualization", type="filepath")
984
+ visualize_btn = gr.Button("Generate Visualization", variant="secondary")
985
+
986
  with gr.TabItem("MIDI Export"):
987
  midi_output = gr.File(label="MIDI File")
988
  download_midi = gr.Button("Generate MIDI File")
989
+
990
+ with gr.TabItem("Exercise Library"):
991
+ refresh_library_btn = gr.Button("Refresh Library", variant="secondary")
992
+ library_dropdown = gr.Dropdown([], label="Saved Exercises", interactive=True)
993
+ load_exercise_btn = gr.Button("Load Selected Exercise", variant="secondary")
994
+
995
  with gr.TabItem("AI Chat"):
996
+ chat_history = gr.Chatbot(label="Practice Assistant", height=400)
997
+ chat_message = gr.Textbox(label="Ask the AI anything about your practice")
 
 
 
998
  send_chat_btn = gr.Button("Send")
999
+ # Toggle UI groups
 
1000
  mode.change(
1001
  fn=lambda m: {
1002
  params_group: gr.update(visible=(m == "Exercise Parameters")),
1003
  prompt_group: gr.update(visible=(m == "Exercise Prompt")),
1004
  },
1005
+ inputs=[mode], outputs=[params_group, prompt_group]
 
1006
  )
 
 
1007
  def generate_caller(mode_val, instrument_val, level_val, key_val,
1008
+ time_sig_val, measures_val, prompt_val, measures_prompt_val):
1009
  real_measures = measures_prompt_val if mode_val == "Exercise Prompt" else measures_val
1010
+ fixed_tempo = 60
1011
+ json_data, mp3_path, tempo, midi, duration, time_sig, total_duration = generate_exercise(
1012
+ instrument_val, level_val, key_val, fixed_tempo, time_sig_val,
1013
+ real_measures, prompt_val, mode_val
1014
  )
1015
+
1016
+ # Calculate difficulty rating
1017
+ rating = calculate_difficulty_rating(json_data, level_val)
1018
+
1019
+ # Generate visualization
1020
+ viz_path = create_visualization(json_data, time_sig_val)
1021
+
1022
+ return json_data, mp3_path, tempo, midi, duration, time_sig, total_duration, rating, viz_path, mp3_path
1023
+
1024
  generate_btn.click(
1025
  fn=generate_caller,
1026
+ inputs=[mode, instrument, level, key, time_signature, measures, custom_prompt, measures_prompt],
1027
+ outputs=[json_output, audio_output, bpm_display, current_midi, duration_display,
1028
+ time_sig_display, duration_sum, difficulty_rating, visualization_output, current_audio_path]
 
1029
  )
1030
 
1031
+ # Visualization button
1032
+ visualize_btn.click(
1033
+ fn=create_visualization,
1034
+ inputs=[json_output, time_signature],
1035
+ outputs=[visualization_output]
1036
+ )
1037
+
1038
+ # Metronome generation
1039
+ def generate_metronome(tempo, time_sig, measures_val):
1040
+ return create_metronome_audio(tempo, time_sig, measures_val)
1041
+
1042
+ metronome_btn.click(
1043
+ fn=generate_metronome,
1044
+ inputs=[metronome_tempo, time_signature, measures],
1045
+ outputs=[metronome_audio]
1046
+ )
1047
+
1048
+ # Save to library function
1049
+ def save_to_library(json_data, instrument_val, level_val, key_val, time_sig_val, tempo_val, audio_path):
1050
+ save_exercise_to_library(
1051
+ json_data, instrument_val, level_val, key_val, time_sig_val, tempo_val, audio_path
1052
+ )
1053
+ return
1054
+
1055
+ save_btn.click(
1056
+ fn=save_to_library,
1057
+ inputs=[json_output, instrument, level, key, time_signature, bpm_display, current_audio_path],
1058
+ outputs=[]
1059
+ )
1060
+
1061
+ # Library functions
1062
+ def refresh_library():
1063
+ exercises = load_exercises_from_library()
1064
+ options = [f"{ex['timestamp']} - {ex['instrument']} ({ex['level']}) - {ex['key']} {ex['time_signature']}"
1065
+ for ex in exercises]
1066
+ return gr.Dropdown.update(choices=options, value=options[0] if options else None)
1067
+
1068
+ refresh_library_btn.click(
1069
+ fn=refresh_library,
1070
+ inputs=[],
1071
+ outputs=[library_dropdown]
1072
+ )
1073
+
1074
+ def load_exercise_from_library(selected_exercise):
1075
+ if not selected_exercise:
1076
+ return None, None, None, None, None, None, 0, None
1077
 
1078
+ exercises = load_exercises_from_library()
1079
+ for i, ex in enumerate(exercises):
1080
+ option = f"{ex['timestamp']} - {ex['instrument']} ({ex['level']}) - {ex['key']} {ex['time_signature']}"
1081
+ if option == selected_exercise:
1082
+ # Load the exercise data
1083
+ json_data = ex['json_data']
1084
+ audio_path = ex['audio_path']
1085
+ tempo = ex['tempo']
1086
+ time_sig = ex['time_signature']
1087
+
1088
+ # Calculate duration sum
1089
+ try:
1090
+ parsed = json.loads(json_data)
1091
+ total_duration = sum(item['duration'] for item in parsed if isinstance(item, dict))
1092
+ except:
1093
+ total_duration = 0
1094
+
1095
+ # Calculate difficulty rating
1096
+ rating = calculate_difficulty_rating(json_data, ex['level'])
1097
+
1098
+ # Generate visualization
1099
+ viz_path = create_visualization(json_data, time_sig)
1100
+
1101
+ # Calculate audio duration
1102
+ try:
1103
+ audio = AudioSegment.from_file(audio_path)
1104
+ duration = f"{audio.duration_seconds:.2f} seconds"
1105
+ except:
1106
+ duration = "Unknown"
1107
+
1108
+ return json_data, audio_path, tempo, duration, time_sig, total_duration, rating, viz_path
1109
 
1110
+ return None, None, None, None, None, None, 0, None
 
1111
 
1112
+ load_exercise_btn.click(
1113
+ fn=load_exercise_from_library,
1114
+ inputs=[library_dropdown],
1115
+ outputs=[json_output, audio_output, bpm_display, duration_display,
1116
+ time_sig_display, duration_sum, difficulty_rating, visualization_output]
1117
+ )
1118
 
1119
+ def save_midi(json_data, instr, time_sig, key_sig="C Major"):
1120
+ try:
1121
+ if not json_data or "Error" in json_data:
1122
+ return None
1123
+
1124
+ parsed = json.loads(json_data)
1125
+
1126
+ # Validate JSON structure
1127
+ if not isinstance(parsed, list):
1128
+ return None
1129
+
1130
+ old_format = []
1131
+ for item in parsed:
1132
+ if isinstance(item, dict) and "note" in item and "duration" in item:
1133
+ old_format.append([item["note"], item["duration"]])
1134
+
1135
+ if not old_format:
1136
+ return None
1137
+
1138
+ # Calculate total units
1139
+ total_units = sum(d[1] for d in old_format)
1140
+ numerator, denominator = map(int, time_sig.split('/'))
1141
+ units_per_measure = numerator * (8 // denominator)
1142
+ measures_est = max(1, round(total_units / units_per_measure))
1143
+
1144
+ # Generate MIDI
1145
+ cumulative = 0
1146
+ scaled_new = []
1147
+ for note, dur in old_format:
1148
+ cumulative += dur
1149
+ scaled_new.append({
1150
+ "note": note,
1151
+ "duration": dur,
1152
+ "cumulative_duration": cumulative
1153
+ })
1154
+
1155
+ midi_obj = json_to_midi(scaled_new, instr, 60, time_sig, measures_est, key=key_sig)
1156
+ midi_path = os.path.join("static", "exercise.mid")
1157
+ midi_obj.save(midi_path)
1158
+ return midi_path
1159
+ except Exception as e:
1160
+ print(f"Error saving MIDI: {e}")
1161
+ return None
1162
+
1163
  download_midi.click(
1164
  fn=save_midi,
1165
+ inputs=[json_output, instrument, time_signature, key],
1166
+ outputs=[midi_output],
1167
  )
 
 
1168
  send_chat_btn.click(
1169
  fn=handle_chat,
1170
  inputs=[chat_message, chat_history, instrument, level],
1171
+ outputs=[chat_message, chat_history],
1172
  )
 
1173
  return demo
1174
 
1175
  # -----------------------------------------------------------------------------
1176
+ # 14. Entry point
1177
  # -----------------------------------------------------------------------------
1178
  if __name__ == "__main__":
1179
  demo = create_ui()