Upload 4 files
Browse files- Dockerfile +6 -1
- README.md +1 -1
- app.py +922 -258
Dockerfile
CHANGED
@@ -5,6 +5,9 @@ RUN apt-get update && apt-get install -y \
|
|
5 |
fluidsynth \
|
6 |
libsndfile1 \
|
7 |
wget \
|
|
|
|
|
|
|
8 |
&& rm -rf /var/lib/apt/lists/*
|
9 |
|
10 |
# Create app directory
|
@@ -27,8 +30,10 @@ RUN mkdir -p /app/soundfonts && \
|
|
27 |
wget -O /app/soundfonts/Clarinet.sf2 https://musical-artifacts.com/artifacts/2744/SalC5Light.sf2 && \
|
28 |
wget -O /app/soundfonts/Flute.sf2 https://musical-artifacts.com/artifacts/2744/SalC5Light.sf2
|
29 |
|
30 |
-
# Create
|
31 |
RUN mkdir -p /app/static
|
|
|
|
|
32 |
|
33 |
# Set environment variables
|
34 |
ENV PYTHONUNBUFFERED=1
|
|
|
5 |
fluidsynth \
|
6 |
libsndfile1 \
|
7 |
wget \
|
8 |
+
libfreetype6-dev \
|
9 |
+
libpng-dev \
|
10 |
+
pkg-config \
|
11 |
&& rm -rf /var/lib/apt/lists/*
|
12 |
|
13 |
# Create app directory
|
|
|
30 |
wget -O /app/soundfonts/Clarinet.sf2 https://musical-artifacts.com/artifacts/2744/SalC5Light.sf2 && \
|
31 |
wget -O /app/soundfonts/Flute.sf2 https://musical-artifacts.com/artifacts/2744/SalC5Light.sf2
|
32 |
|
33 |
+
# Create necessary directories
|
34 |
RUN mkdir -p /app/static
|
35 |
+
RUN mkdir -p /app/exercise_library
|
36 |
+
RUN mkdir -p /app/temp_audio
|
37 |
|
38 |
# Set environment variables
|
39 |
ENV PYTHONUNBUFFERED=1
|
README.md
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
---
|
2 |
-
title:
|
3 |
emoji: π
|
4 |
colorFrom: indigo
|
5 |
colorTo: indigo
|
|
|
1 |
---
|
2 |
+
title: Ai Music Generation Version 2
|
3 |
emoji: π
|
4 |
colorFrom: indigo
|
5 |
colorTo: indigo
|
app.py
CHANGED
@@ -1,21 +1,22 @@
|
|
1 |
"""
|
2 |
-
Adaptive Music Exercise Generator (
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
Perfectly fits user-specified number of measures and time signature
|
7 |
-
|
8 |
Major updates:
|
9 |
-
-
|
10 |
-
-
|
11 |
-
-
|
12 |
-
-
|
13 |
-
-
|
14 |
-
-
|
|
|
|
|
|
|
15 |
"""
|
16 |
|
17 |
# -----------------------------------------------------------------------------
|
18 |
-
# 1. Runtime-time package installation
|
19 |
# -----------------------------------------------------------------------------
|
20 |
import sys
|
21 |
import subprocess
|
@@ -31,7 +32,8 @@ def install(packages: List[str]):
|
|
31 |
|
32 |
install([
|
33 |
"mido", "midi2audio", "pydub", "gradio",
|
34 |
-
"requests", "
|
|
|
35 |
])
|
36 |
|
37 |
# -----------------------------------------------------------------------------
|
@@ -44,18 +46,28 @@ import tempfile
|
|
44 |
import mido
|
45 |
from mido import Message, MidiFile, MidiTrack, MetaMessage
|
46 |
import re
|
47 |
-
import
|
48 |
-
import shutil
|
49 |
-
import subprocess as sp
|
50 |
from midi2audio import FluidSynth
|
51 |
from pydub import AudioSegment
|
52 |
import gradio as gr
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
53 |
|
54 |
# -----------------------------------------------------------------------------
|
55 |
-
# 3. Configuration & constants
|
56 |
# -----------------------------------------------------------------------------
|
57 |
MISTRAL_API_URL = "https://api.mistral.ai/v1/chat/completions"
|
58 |
-
MISTRAL_API_KEY = "yQdfM8MLbX9uhInQ7id4iUTwN4h4pDLX"
|
59 |
|
60 |
SOUNDFONT_URLS = {
|
61 |
"Trumpet": "https://github.com/FluidSynth/fluidsynth/raw/master/sf2/Trumpet.sf2",
|
@@ -65,8 +77,9 @@ SOUNDFONT_URLS = {
|
|
65 |
"Flute": "https://musical-artifacts.com/artifacts/2744/SalC5Light.sf2",
|
66 |
}
|
67 |
|
68 |
-
|
69 |
-
|
|
|
70 |
|
71 |
if not os.path.exists('/usr/bin/fluidsynth'):
|
72 |
try:
|
@@ -75,75 +88,167 @@ if not os.path.exists('/usr/bin/fluidsynth'):
|
|
75 |
print("Could not install FluidSynth automatically. Please install it manually.")
|
76 |
|
77 |
os.makedirs("static", exist_ok=True)
|
|
|
|
|
78 |
|
79 |
# -----------------------------------------------------------------------------
|
80 |
-
# 4. Music theory helpers (note names βοΈ MIDI numbers)
|
81 |
# -----------------------------------------------------------------------------
|
82 |
-
NOTE_MAP = {
|
83 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
84 |
|
85 |
INSTRUMENT_PROGRAMS: Dict[str, int] = {
|
86 |
-
"Piano": 0,
|
87 |
-
"Clarinet": 71,
|
88 |
}
|
89 |
|
90 |
-
def
|
91 |
-
"""
|
92 |
-
|
93 |
-
n = re.sub(r'\bFLAT\b', 'B', n, flags=re.I)
|
94 |
-
n = re.sub(r'\bSHARP\b', '#', n, flags=re.I)
|
95 |
-
n = re.sub(r'([A-G])\s*-\s*FLAT', r'\1B', n, flags=re.I)
|
96 |
-
n = re.sub(r'([A-G])\s*-\s*SHARP', r'\1#', n, flags=re.I)
|
97 |
-
return n
|
98 |
|
99 |
def note_name_to_midi(note: str) -> int:
|
100 |
-
|
101 |
-
|
|
|
|
|
|
|
102 |
if not match:
|
103 |
raise ValueError(f"Invalid note: {note}")
|
104 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
105 |
if pitch not in NOTE_MAP:
|
106 |
raise ValueError(f"Invalid pitch: {pitch}")
|
107 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
108 |
|
109 |
# -----------------------------------------------------------------------------
|
110 |
-
# 5.
|
111 |
# -----------------------------------------------------------------------------
|
112 |
def scale_json_durations(json_data, target_units: int) -> list:
|
113 |
-
"""
|
114 |
-
|
115 |
-
total = sum(
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
126 |
|
127 |
# -----------------------------------------------------------------------------
|
128 |
-
# 6. MIDI from scaled JSON (integer durations)
|
129 |
# -----------------------------------------------------------------------------
|
130 |
-
def json_to_midi(json_data: list, instrument: str, tempo: int,
|
131 |
-
time_signature: str, measures: int) -> MidiFile:
|
132 |
mid = MidiFile(ticks_per_beat=TICKS_PER_BEAT)
|
133 |
track = MidiTrack(); mid.tracks.append(track)
|
134 |
program = INSTRUMENT_PROGRAMS.get(instrument, 56)
|
135 |
-
|
136 |
|
137 |
-
|
|
|
|
|
|
|
138 |
track.append(MetaMessage('set_tempo', tempo=mido.bpm2tempo(tempo), time=0))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
139 |
track.append(Message('program_change', program=program, time=0))
|
140 |
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
147 |
return mid
|
148 |
|
149 |
# -----------------------------------------------------------------------------
|
@@ -166,58 +271,127 @@ def midi_to_mp3(midi_obj: MidiFile, instrument: str = "Trumpet") -> Tuple[str, f
|
|
166 |
wav_path = mid_file.name.replace(".mid", ".wav")
|
167 |
mp3_path = mid_file.name.replace(".mid", ".mp3")
|
168 |
sf2_path = get_soundfont(instrument)
|
169 |
-
|
170 |
try:
|
171 |
sp.run([
|
172 |
'fluidsynth', '-ni', sf2_path, mid_file.name,
|
173 |
-
'-F', wav_path, '-r', '44100'
|
174 |
], check=True, capture_output=True)
|
175 |
except Exception:
|
176 |
-
fs = FluidSynth(sf2_path, sample_rate=44100)
|
177 |
fs.midi_to_audio(mid_file.name, wav_path)
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
|
|
|
|
|
192 |
|
193 |
# -----------------------------------------------------------------------------
|
194 |
-
# 8.
|
195 |
# -----------------------------------------------------------------------------
|
196 |
def get_fallback_exercise(instrument: str, level: str, key: str,
|
197 |
time_sig: str, measures: int) -> str:
|
198 |
-
|
199 |
-
"
|
200 |
-
"
|
201 |
-
"
|
202 |
-
"
|
203 |
-
"
|
|
|
|
|
204 |
}
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
217 |
|
218 |
-
# -----------------------------------------------------------------------------
|
219 |
-
# 9. Difficulty level helpers (FIXED)
|
220 |
-
# -----------------------------------------------------------------------------
|
221 |
def get_style_based_on_level(level: str) -> str:
|
222 |
styles = {
|
223 |
"Beginner": ["simple", "legato", "stepwise"],
|
@@ -235,7 +409,7 @@ def get_technique_based_on_level(level: str) -> str:
|
|
235 |
return random.choice(techniques.get(level, ["with slurs"]))
|
236 |
|
237 |
# -----------------------------------------------------------------------------
|
238 |
-
#
|
239 |
# -----------------------------------------------------------------------------
|
240 |
def query_mistral(prompt: str, instrument: str, level: str, key: str,
|
241 |
time_sig: str, measures: int) -> str:
|
@@ -243,117 +417,213 @@ def query_mistral(prompt: str, instrument: str, level: str, key: str,
|
|
243 |
"Authorization": f"Bearer {MISTRAL_API_KEY}",
|
244 |
"Content-Type": "application/json",
|
245 |
}
|
246 |
-
numerator = int
|
247 |
-
target_units = measures * numerator
|
248 |
|
249 |
-
|
250 |
-
|
251 |
-
|
252 |
-
|
253 |
-
|
254 |
-
|
255 |
-
"
|
|
|
|
|
256 |
)
|
257 |
-
|
|
|
|
|
|
|
|
|
258 |
if prompt.strip():
|
259 |
-
user_prompt =
|
|
|
|
|
|
|
260 |
else:
|
261 |
-
# FIXED: Incorporate difficulty level
|
262 |
style = get_style_based_on_level(level)
|
263 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
264 |
user_prompt = (
|
265 |
-
f"Create a {style} {instrument.lower()} exercise in {key}
|
266 |
-
f"{
|
|
|
|
|
|
|
|
|
|
|
267 |
)
|
268 |
-
|
269 |
payload = {
|
270 |
"model": "mistral-medium",
|
271 |
"messages": [
|
272 |
-
{"role": "system", "content":
|
273 |
-
{"role": "user", "content": user_prompt}
|
274 |
],
|
275 |
-
"temperature": 0.
|
276 |
-
"max_tokens":
|
|
|
|
|
|
|
277 |
}
|
278 |
-
|
279 |
try:
|
280 |
response = requests.post(MISTRAL_API_URL, headers=headers, json=payload)
|
281 |
response.raise_for_status()
|
282 |
-
|
|
|
283 |
except Exception as e:
|
284 |
print(f"Error querying Mistral API: {e}")
|
285 |
return get_fallback_exercise(instrument, level, key, time_sig, measures)
|
286 |
|
287 |
# -----------------------------------------------------------------------------
|
288 |
-
#
|
289 |
# -----------------------------------------------------------------------------
|
290 |
def safe_parse_json(text: str) -> Optional[list]:
|
291 |
try:
|
292 |
-
text = text.replace("'", '"')
|
293 |
-
|
294 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
295 |
except Exception as e:
|
296 |
print(f"JSON parsing error: {e}\nRaw text: {text}")
|
297 |
return None
|
298 |
|
299 |
# -----------------------------------------------------------------------------
|
300 |
-
#
|
301 |
# -----------------------------------------------------------------------------
|
302 |
-
def generate_exercise(instrument: str, level: str, key: str, tempo: int,
|
303 |
-
|
304 |
-
mode: str) -> Tuple[str, Optional[str], str, MidiFile, str, str, int]:
|
305 |
try:
|
306 |
prompt_to_use = custom_prompt if mode == "Exercise Prompt" else ""
|
307 |
-
|
308 |
-
parsed = safe_parse_json(
|
309 |
-
|
310 |
if not parsed:
|
311 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
312 |
|
313 |
-
# Scale to exact integer durations
|
314 |
-
numerator = int(time_signature.split('/')[0])
|
315 |
-
target_units = measures * numerator
|
316 |
-
scaled = scale_json_durations(parsed, target_units)
|
317 |
-
|
318 |
# Calculate total duration units
|
319 |
-
total_duration =
|
320 |
-
|
321 |
# Generate MIDI and audio
|
322 |
-
midi = json_to_midi(
|
323 |
mp3_path, real_duration = midi_to_mp3(midi, instrument)
|
324 |
-
|
325 |
-
return
|
326 |
-
midi, f"{real_duration:.2f} seconds", time_signature, total_duration)
|
327 |
except Exception as e:
|
328 |
return f"Error: {str(e)}", None, str(tempo), None, "0", time_signature, 0
|
329 |
|
330 |
# -----------------------------------------------------------------------------
|
331 |
-
#
|
332 |
# -----------------------------------------------------------------------------
|
333 |
def handle_chat(message: str, history: List, instrument: str, level: str):
|
334 |
if not message.strip():
|
335 |
return "", history
|
336 |
-
|
337 |
-
messages = [{"role": "system",
|
338 |
-
"content": f"You are a {instrument} teacher for {level} students."}]
|
339 |
-
|
340 |
for user_msg, assistant_msg in history:
|
341 |
messages.append({"role": "user", "content": user_msg})
|
342 |
messages.append({"role": "assistant", "content": assistant_msg})
|
343 |
-
|
344 |
messages.append({"role": "user", "content": message})
|
345 |
-
|
346 |
-
|
347 |
-
"Authorization": f"Bearer {MISTRAL_API_KEY}",
|
348 |
-
"Content-Type": "application/json"
|
349 |
-
}
|
350 |
-
payload = {
|
351 |
-
"model": "mistral-medium",
|
352 |
-
"messages": messages,
|
353 |
-
"temperature": 0.7,
|
354 |
-
"max_tokens": 500
|
355 |
-
}
|
356 |
-
|
357 |
try:
|
358 |
response = requests.post(MISTRAL_API_URL, headers=headers, json=payload)
|
359 |
response.raise_for_status()
|
@@ -365,151 +635,545 @@ def handle_chat(message: str, history: List, instrument: str, level: str):
|
|
365 |
return "", history
|
366 |
|
367 |
# -----------------------------------------------------------------------------
|
368 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
369 |
# -----------------------------------------------------------------------------
|
370 |
def create_ui() -> gr.Blocks:
|
371 |
with gr.Blocks(title="Adaptive Music Exercise Generator", theme="soft") as demo:
|
372 |
-
gr.Markdown("# πΌ Adaptive Music Exercise Generator
|
373 |
current_midi = gr.State(None)
|
374 |
current_exercise = gr.State("")
|
|
|
375 |
|
376 |
-
mode = gr.Radio(["Exercise Parameters","Exercise Prompt"],
|
377 |
-
value="Exercise Parameters", label="Generation Mode")
|
378 |
-
|
379 |
with gr.Row():
|
380 |
with gr.Column(scale=1):
|
381 |
with gr.Group(visible=True) as params_group:
|
382 |
gr.Markdown("### Exercise Parameters")
|
383 |
-
instrument = gr.Dropdown(
|
384 |
-
|
385 |
-
|
386 |
-
|
387 |
-
|
388 |
-
|
389 |
-
|
390 |
-
|
391 |
-
|
392 |
-
|
393 |
-
|
394 |
-
value="C Major", label="Key Signature"
|
395 |
-
)
|
396 |
-
time_signature = gr.Dropdown(
|
397 |
-
["3/4", "4/4"], value="4/4", label="Time Signature"
|
398 |
-
)
|
399 |
-
measures = gr.Radio(
|
400 |
-
[4, 8], value=4, label="Length (measures)"
|
401 |
-
)
|
402 |
-
|
403 |
with gr.Group(visible=False) as prompt_group:
|
404 |
gr.Markdown("### Exercise Prompt")
|
405 |
-
custom_prompt = gr.Textbox(
|
406 |
-
|
407 |
-
label="Describe your exercise (e.g. 'Jazz trumpet exercise with syncopation')",
|
408 |
-
lines=3
|
409 |
-
)
|
410 |
-
measures_prompt = gr.Radio(
|
411 |
-
[4, 8], value=4, label="Length (measures)"
|
412 |
-
)
|
413 |
-
|
414 |
generate_btn = gr.Button("Generate Exercise", variant="primary")
|
415 |
-
|
416 |
with gr.Column(scale=2):
|
417 |
with gr.Tabs():
|
418 |
with gr.TabItem("Exercise Player"):
|
419 |
-
audio_output = gr.Audio(label="Generated Exercise",
|
420 |
-
|
421 |
-
|
422 |
-
|
423 |
-
|
424 |
-
|
425 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
426 |
with gr.TabItem("Exercise Data"):
|
427 |
-
json_output = gr.Code(
|
428 |
-
label="JSON Representation",
|
429 |
-
language="json",
|
430 |
-
interactive=True
|
431 |
-
)
|
432 |
-
# NEW: Duration sum display
|
433 |
duration_sum = gr.Number(
|
434 |
-
label="Total Duration Units (
|
435 |
interactive=False,
|
436 |
precision=0
|
437 |
)
|
438 |
-
|
|
|
|
|
|
|
|
|
439 |
with gr.TabItem("MIDI Export"):
|
440 |
midi_output = gr.File(label="MIDI File")
|
441 |
download_midi = gr.Button("Generate MIDI File")
|
442 |
-
|
|
|
|
|
|
|
|
|
|
|
443 |
with gr.TabItem("AI Chat"):
|
444 |
-
chat_history = gr.Chatbot(label="Practice Assistant",
|
445 |
-
|
446 |
-
chat_message = gr.Textbox(
|
447 |
-
label="Ask about technique, theory, or practice strategies"
|
448 |
-
)
|
449 |
send_chat_btn = gr.Button("Send")
|
450 |
-
|
451 |
-
# UI visibility toggling
|
452 |
mode.change(
|
453 |
fn=lambda m: {
|
454 |
params_group: gr.update(visible=(m == "Exercise Parameters")),
|
455 |
prompt_group: gr.update(visible=(m == "Exercise Prompt")),
|
456 |
},
|
457 |
-
inputs=[mode],
|
458 |
-
outputs=[params_group, prompt_group]
|
459 |
)
|
460 |
-
|
461 |
-
# Generate exercise handler
|
462 |
def generate_caller(mode_val, instrument_val, level_val, key_val,
|
463 |
-
|
464 |
real_measures = measures_prompt_val if mode_val == "Exercise Prompt" else measures_val
|
465 |
-
fixed_tempo = 60
|
466 |
-
|
467 |
-
instrument_val, level_val, key_val, fixed_tempo,
|
468 |
-
|
469 |
)
|
470 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
471 |
generate_btn.click(
|
472 |
fn=generate_caller,
|
473 |
-
inputs=[mode, instrument, level, key, time_signature,
|
474 |
-
|
475 |
-
|
476 |
-
current_midi, duration_display, time_sig_display, duration_sum]
|
477 |
)
|
478 |
|
479 |
-
#
|
480 |
-
|
481 |
-
|
482 |
-
|
483 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
484 |
|
485 |
-
|
486 |
-
|
487 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
488 |
|
489 |
-
|
490 |
-
midi_obj = json_to_midi(scaled, instr, 60, time_sig, measures_est)
|
491 |
|
492 |
-
|
493 |
-
|
494 |
-
|
|
|
|
|
|
|
495 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
496 |
download_midi.click(
|
497 |
fn=save_midi,
|
498 |
-
inputs=[json_output, instrument, time_signature],
|
499 |
-
outputs=[midi_output]
|
500 |
)
|
501 |
-
|
502 |
-
# Chat handler
|
503 |
send_chat_btn.click(
|
504 |
fn=handle_chat,
|
505 |
inputs=[chat_message, chat_history, instrument, level],
|
506 |
-
outputs=[chat_message, chat_history]
|
507 |
)
|
508 |
-
|
509 |
return demo
|
510 |
|
511 |
# -----------------------------------------------------------------------------
|
512 |
-
#
|
513 |
# -----------------------------------------------------------------------------
|
514 |
if __name__ == "__main__":
|
515 |
demo = create_ui()
|
|
|
1 |
"""
|
2 |
+
Adaptive Music Exercise Generator (Strict Duration Enforcement)
|
3 |
+
==============================================================
|
4 |
+
Generates custom musical exercises with LLM, perfectly fit to user-specified number of measures
|
5 |
+
AND time signature, guaranteeing exact durations in MIDI and in the UI!
|
|
|
|
|
6 |
Major updates:
|
7 |
+
- Changed base duration unit from 16th notes to 8th notes (1 unit = 8th note)
|
8 |
+
- Updated all calculations and prompts to use new duration system
|
9 |
+
- Duration sum display now shows total in 8th notes
|
10 |
+
- Maintained all original functionality
|
11 |
+
- Added cumulative duration tracking
|
12 |
+
- Enforced JSON output format with note, duration, cumulative_duration
|
13 |
+
- Enhanced rest handling and JSON parsing
|
14 |
+
- Fixed JSON parsing errors for 8-measure exercises
|
15 |
+
- Added robust error handling for MIDI generation
|
16 |
"""
|
17 |
|
18 |
# -----------------------------------------------------------------------------
|
19 |
+
# 1. Runtime-time package installation (for fresh containers/Colab/etc)
|
20 |
# -----------------------------------------------------------------------------
|
21 |
import sys
|
22 |
import subprocess
|
|
|
32 |
|
33 |
install([
|
34 |
"mido", "midi2audio", "pydub", "gradio",
|
35 |
+
"requests", "numpy", "matplotlib", "librosa", "scipy",
|
36 |
+
"uuid", "datetime"
|
37 |
])
|
38 |
|
39 |
# -----------------------------------------------------------------------------
|
|
|
46 |
import mido
|
47 |
from mido import Message, MidiFile, MidiTrack, MetaMessage
|
48 |
import re
|
49 |
+
from io import BytesIO
|
|
|
|
|
50 |
from midi2audio import FluidSynth
|
51 |
from pydub import AudioSegment
|
52 |
import gradio as gr
|
53 |
+
import numpy as np
|
54 |
+
import matplotlib.pyplot as plt
|
55 |
+
import librosa
|
56 |
+
from scipy.io import wavfile
|
57 |
+
import os
|
58 |
+
import subprocess as sp
|
59 |
+
import base64
|
60 |
+
import shutil
|
61 |
+
import ast
|
62 |
+
import uuid
|
63 |
+
from datetime import datetime
|
64 |
+
import time
|
65 |
|
66 |
# -----------------------------------------------------------------------------
|
67 |
+
# 3. Configuration & constants (UPDATED TO USE 8TH NOTES)
|
68 |
# -----------------------------------------------------------------------------
|
69 |
MISTRAL_API_URL = "https://api.mistral.ai/v1/chat/completions"
|
70 |
+
MISTRAL_API_KEY = "yQdfM8MLbX9uhInQ7id4iUTwN4h4pDLX" # β Replace with your key!
|
71 |
|
72 |
SOUNDFONT_URLS = {
|
73 |
"Trumpet": "https://github.com/FluidSynth/fluidsynth/raw/master/sf2/Trumpet.sf2",
|
|
|
77 |
"Flute": "https://musical-artifacts.com/artifacts/2744/SalC5Light.sf2",
|
78 |
}
|
79 |
|
80 |
+
SAMPLE_RATE = 44100 # Hz
|
81 |
+
TICKS_PER_BEAT = 480 # Standard MIDI resolution
|
82 |
+
TICKS_PER_8TH = TICKS_PER_BEAT // 2 # 240 ticks per 8th note (UPDATED)
|
83 |
|
84 |
if not os.path.exists('/usr/bin/fluidsynth'):
|
85 |
try:
|
|
|
88 |
print("Could not install FluidSynth automatically. Please install it manually.")
|
89 |
|
90 |
os.makedirs("static", exist_ok=True)
|
91 |
+
os.makedirs("temp_audio", exist_ok=True)
|
92 |
+
os.makedirs("saved_exercises", exist_ok=True)
|
93 |
|
94 |
# -----------------------------------------------------------------------------
|
95 |
+
# 4. Music theory helpers (note names βοΈ MIDI numbers) - ENHANCED REST HANDLING
|
96 |
# -----------------------------------------------------------------------------
|
97 |
+
NOTE_MAP: Dict[str, int] = {
|
98 |
+
"C": 0, "C#": 1, "DB": 1,
|
99 |
+
"D": 2, "D#": 3, "EB": 3,
|
100 |
+
"E": 4, "F": 5, "F#": 6, "GB": 6,
|
101 |
+
"G": 7, "G#": 8, "AB": 8,
|
102 |
+
"A": 9, "A#": 10, "BB": 10,
|
103 |
+
"B": 11,
|
104 |
+
}
|
105 |
+
|
106 |
+
REST_INDICATORS = ["rest", "r", "Rest", "R", "P", "p", "pause"]
|
107 |
|
108 |
INSTRUMENT_PROGRAMS: Dict[str, int] = {
|
109 |
+
"Piano": 0, "Trumpet": 56, "Violin": 40,
|
110 |
+
"Clarinet": 71, "Flute": 73,
|
111 |
}
|
112 |
|
113 |
+
def is_rest(note: str) -> bool:
|
114 |
+
"""Check if a note string represents a rest."""
|
115 |
+
return note.strip().lower() in [r.lower() for r in REST_INDICATORS]
|
|
|
|
|
|
|
|
|
|
|
116 |
|
117 |
def note_name_to_midi(note: str) -> int:
|
118 |
+
if is_rest(note):
|
119 |
+
return -1 # Special value for rests
|
120 |
+
|
121 |
+
# Allow both scientific (C4) and Helmholtz (C') notation
|
122 |
+
match = re.match(r"([A-Ga-g][#b]?)(\'*)(\d?)", note)
|
123 |
if not match:
|
124 |
raise ValueError(f"Invalid note: {note}")
|
125 |
+
|
126 |
+
pitch, apostrophes, octave = match.groups()
|
127 |
+
pitch = pitch.upper().replace('b', 'B')
|
128 |
+
|
129 |
+
# Handle Helmholtz notation (C' = C5, C'' = C6, etc)
|
130 |
+
octave_num = 4
|
131 |
+
if octave:
|
132 |
+
octave_num = int(octave)
|
133 |
+
elif apostrophes:
|
134 |
+
octave_num = 5 + len(apostrophes)
|
135 |
+
|
136 |
if pitch not in NOTE_MAP:
|
137 |
raise ValueError(f"Invalid pitch: {pitch}")
|
138 |
+
|
139 |
+
return NOTE_MAP[pitch] + (octave_num + 1) * 12
|
140 |
+
|
141 |
+
def midi_to_note_name(midi_num: int) -> str:
|
142 |
+
if midi_num == -1:
|
143 |
+
return "Rest"
|
144 |
+
notes = ["C", "C#", "D", "D#", "E", "F", "F#", "G", "G#", "A", "A#", "B"]
|
145 |
+
octave = (midi_num // 12) - 1
|
146 |
+
return f"{notes[midi_num % 12]}{octave}"
|
147 |
|
148 |
# -----------------------------------------------------------------------------
|
149 |
+
# 5. Duration scaling: guarantee the output sums to requested total (using integers)
|
150 |
# -----------------------------------------------------------------------------
|
151 |
def scale_json_durations(json_data, target_units: int) -> list:
|
152 |
+
"""Scales durations so that their sum is exactly target_units (8th notes)."""
|
153 |
+
durations = [int(d) for _, d in json_data]
|
154 |
+
total = sum(durations)
|
155 |
+
if total == 0:
|
156 |
+
return json_data
|
157 |
+
|
158 |
+
# Calculate proportional scaling with integer arithmetic
|
159 |
+
scaled = []
|
160 |
+
remainder = target_units
|
161 |
+
for i, (note, d) in enumerate(json_data):
|
162 |
+
if i < len(json_data) - 1:
|
163 |
+
# Proportional allocation
|
164 |
+
portion = max(1, round(d * target_units / total))
|
165 |
+
scaled.append([note, portion])
|
166 |
+
remainder -= portion
|
167 |
+
else:
|
168 |
+
# Last note gets all remaining units
|
169 |
+
scaled.append([note, max(1, remainder)])
|
170 |
+
|
171 |
+
return scaled
|
172 |
|
173 |
# -----------------------------------------------------------------------------
|
174 |
+
# 6. MIDI from scaled JSON (using integer durations) - UPDATED REST HANDLING
|
175 |
# -----------------------------------------------------------------------------
|
176 |
+
def json_to_midi(json_data: list, instrument: str, tempo: int, time_signature: str, measures: int, key: str = "C Major") -> MidiFile:
|
|
|
177 |
mid = MidiFile(ticks_per_beat=TICKS_PER_BEAT)
|
178 |
track = MidiTrack(); mid.tracks.append(track)
|
179 |
program = INSTRUMENT_PROGRAMS.get(instrument, 56)
|
180 |
+
numerator, denominator = map(int, time_signature.split('/'))
|
181 |
|
182 |
+
# Add time signature meta message
|
183 |
+
track.append(MetaMessage('time_signature', numerator=numerator,
|
184 |
+
denominator=denominator, time=0))
|
185 |
+
# Add tempo meta message
|
186 |
track.append(MetaMessage('set_tempo', tempo=mido.bpm2tempo(tempo), time=0))
|
187 |
+
|
188 |
+
# Add key signature meta message based on the key
|
189 |
+
# For MIDI key signatures, the key parameter expects a string like 'C', 'F#m', etc.
|
190 |
+
key_map = {
|
191 |
+
"C Major": "C",
|
192 |
+
"G Major": "G",
|
193 |
+
"D Major": "D",
|
194 |
+
"F Major": "F",
|
195 |
+
"Bb Major": "Bb",
|
196 |
+
"A Minor": "Am",
|
197 |
+
"E Minor": "Em",
|
198 |
+
}
|
199 |
+
|
200 |
+
# Use the provided key or default to C major if key not found
|
201 |
+
midi_key = key_map.get(key, "C")
|
202 |
+
# The 'key' parameter in MetaMessage expects a string like 'C', 'F#m', etc.
|
203 |
+
track.append(MetaMessage('key_signature', key=midi_key, time=0))
|
204 |
+
|
205 |
+
# Set instrument program
|
206 |
track.append(Message('program_change', program=program, time=0))
|
207 |
|
208 |
+
# Accumulator for rest durations
|
209 |
+
accumulated_rest = 0
|
210 |
+
|
211 |
+
for note_item in json_data:
|
212 |
+
try:
|
213 |
+
# Handle both formats: [note, duration] and {note, duration}
|
214 |
+
if isinstance(note_item, list) and len(note_item) == 2:
|
215 |
+
note_name, duration_units = note_item
|
216 |
+
elif isinstance(note_item, dict):
|
217 |
+
note_name = note_item["note"]
|
218 |
+
duration_units = note_item["duration"]
|
219 |
+
else:
|
220 |
+
print(f"Unsupported note format: {note_item}")
|
221 |
+
continue
|
222 |
+
|
223 |
+
ticks = int(duration_units * TICKS_PER_8TH)
|
224 |
+
ticks = max(ticks, 1)
|
225 |
+
|
226 |
+
if is_rest(note_name):
|
227 |
+
# Accumulate rest duration
|
228 |
+
accumulated_rest += ticks
|
229 |
+
else:
|
230 |
+
# Process any accumulated rest first
|
231 |
+
if accumulated_rest > 0:
|
232 |
+
# Add rest by creating a silent note (velocity 0) that won't be heard
|
233 |
+
# Or just skip and use accumulated_rest in timing
|
234 |
+
# We'll just add the time to the next note
|
235 |
+
track.append(Message('note_on', note=0, velocity=0, time=accumulated_rest))
|
236 |
+
track.append(Message('note_off', note=0, velocity=0, time=0))
|
237 |
+
accumulated_rest = 0
|
238 |
+
|
239 |
+
# Process actual note
|
240 |
+
note_num = note_name_to_midi(note_name)
|
241 |
+
velocity = random.randint(60, 100)
|
242 |
+
track.append(Message('note_on', note=note_num, velocity=velocity, time=0))
|
243 |
+
track.append(Message('note_off', note=note_num, velocity=velocity, time=ticks))
|
244 |
+
except Exception as e:
|
245 |
+
print(f"Error parsing note {note_item}: {e}")
|
246 |
+
|
247 |
+
# Handle trailing rest
|
248 |
+
if accumulated_rest > 0:
|
249 |
+
track.append(Message('note_on', note=0, velocity=0, time=accumulated_rest))
|
250 |
+
track.append(Message('note_off', note=0, velocity=0, time=0))
|
251 |
+
|
252 |
return mid
|
253 |
|
254 |
# -----------------------------------------------------------------------------
|
|
|
271 |
wav_path = mid_file.name.replace(".mid", ".wav")
|
272 |
mp3_path = mid_file.name.replace(".mid", ".mp3")
|
273 |
sf2_path = get_soundfont(instrument)
|
|
|
274 |
try:
|
275 |
sp.run([
|
276 |
'fluidsynth', '-ni', sf2_path, mid_file.name,
|
277 |
+
'-F', wav_path, '-r', '44100', '-g', '1.0'
|
278 |
], check=True, capture_output=True)
|
279 |
except Exception:
|
280 |
+
fs = FluidSynth(sf2_path, sample_rate=44100, gain=1.0)
|
281 |
fs.midi_to_audio(mid_file.name, wav_path)
|
282 |
+
try:
|
283 |
+
sound = AudioSegment.from_wav(wav_path)
|
284 |
+
if instrument == "Trumpet":
|
285 |
+
sound = sound.high_pass_filter(200)
|
286 |
+
elif instrument == "Violin":
|
287 |
+
sound = sound.low_pass_filter(5000)
|
288 |
+
sound.export(mp3_path, format="mp3")
|
289 |
+
static_mp3_path = os.path.join('static', os.path.basename(mp3_path))
|
290 |
+
shutil.move(mp3_path, static_mp3_path)
|
291 |
+
return static_mp3_path, sound.duration_seconds
|
292 |
+
finally:
|
293 |
+
for f in [mid_file.name, wav_path]:
|
294 |
+
try:
|
295 |
+
os.remove(f)
|
296 |
+
except FileNotFoundError:
|
297 |
+
pass
|
298 |
|
299 |
# -----------------------------------------------------------------------------
|
300 |
+
# 8. Prompt engineering for variety (using integer durations) - UPDATED DURATION SYSTEM
|
301 |
# -----------------------------------------------------------------------------
|
302 |
def get_fallback_exercise(instrument: str, level: str, key: str,
|
303 |
time_sig: str, measures: int) -> str:
|
304 |
+
key_notes = {
|
305 |
+
"C Major": ["C4", "D4", "E4", "F4", "G4", "A4", "B4"],
|
306 |
+
"G Major": ["G3", "A3", "B3", "C4", "D4", "E4", "F#4"],
|
307 |
+
"D Major": ["D4", "E4", "F#4", "G4", "A4", "B4", "C#5"],
|
308 |
+
"F Major": ["F3", "G3", "A3", "Bb3", "C4", "D4", "E4"],
|
309 |
+
"Bb Major": ["Bb3", "C4", "D4", "Eb4", "F4", "G4", "A4"],
|
310 |
+
"A Minor": ["A3", "B3", "C4", "D4", "E4", "F4", "G4"],
|
311 |
+
"E Minor": ["E3", "F#3", "G3", "A3", "B3", "C4", "D4"],
|
312 |
}
|
313 |
+
|
314 |
+
# Get fundamental note from key signature
|
315 |
+
fundamental_note = key.split()[0] # Gets 'C' from 'C Major' or 'A' from 'A Minor'
|
316 |
+
is_major = "Major" in key
|
317 |
+
|
318 |
+
# Get notes for the key
|
319 |
+
notes = key_notes.get(key, key_notes["C Major"])
|
320 |
+
|
321 |
+
# Find fundamental note with octave for ending
|
322 |
+
fundamental_with_octave = None
|
323 |
+
for note in notes:
|
324 |
+
if note.startswith(fundamental_note):
|
325 |
+
fundamental_with_octave = note
|
326 |
+
break
|
327 |
+
|
328 |
+
# If not found, use the first note (should not happen with our key definitions)
|
329 |
+
if not fundamental_with_octave:
|
330 |
+
fundamental_with_octave = notes[0]
|
331 |
+
|
332 |
+
numerator, denominator = map(int, time_sig.split('/'))
|
333 |
+
|
334 |
+
# Calculate units based on 8th notes
|
335 |
+
units_per_measure = numerator * (8 // denominator)
|
336 |
+
target_units = measures * units_per_measure
|
337 |
+
|
338 |
+
# Create a rhythm pattern based on time signature
|
339 |
+
if numerator == 3:
|
340 |
+
rhythm = [2, 1, 2, 1, 2] # 3/4 pattern
|
341 |
+
else:
|
342 |
+
rhythm = [2, 2, 1, 1, 2, 2] # 4/4 pattern
|
343 |
+
|
344 |
+
# Build exercise
|
345 |
+
result = []
|
346 |
+
cumulative = 0
|
347 |
+
current_units = 0
|
348 |
+
|
349 |
+
# Reserve at least 2 units for the final note
|
350 |
+
final_note_duration = min(4, max(2, rhythm[0])) # Between 2 and 4 units
|
351 |
+
available_units = target_units - final_note_duration
|
352 |
+
|
353 |
+
# Generate notes until we reach the available units
|
354 |
+
while current_units < available_units:
|
355 |
+
# Avoid minor 7th in major keys
|
356 |
+
if is_major:
|
357 |
+
# Filter out minor 7th notes (e.g., Bb in C major)
|
358 |
+
available_notes = [n for n in notes if not (n.startswith("Bb") and key == "C Major") and
|
359 |
+
not (n.startswith("F") and key == "G Major") and
|
360 |
+
not (n.startswith("C") and key == "D Major") and
|
361 |
+
not (n.startswith("Eb") and key == "F Major") and
|
362 |
+
not (n.startswith("Ab") and key == "Bb Major")]
|
363 |
+
else:
|
364 |
+
available_notes = notes
|
365 |
+
|
366 |
+
note = random.choice(available_notes)
|
367 |
+
dur = random.choice(rhythm)
|
368 |
+
|
369 |
+
# Don't exceed available units
|
370 |
+
if current_units + dur > available_units:
|
371 |
+
dur = available_units - current_units
|
372 |
+
if dur <= 0:
|
373 |
+
break
|
374 |
+
|
375 |
+
cumulative += dur
|
376 |
+
current_units += dur
|
377 |
+
result.append({
|
378 |
+
"note": note,
|
379 |
+
"duration": dur,
|
380 |
+
"cumulative_duration": cumulative
|
381 |
+
})
|
382 |
+
|
383 |
+
# Add the final note (fundamental of the key)
|
384 |
+
final_duration = target_units - current_units
|
385 |
+
if final_duration > 0:
|
386 |
+
cumulative += final_duration
|
387 |
+
result.append({
|
388 |
+
"note": fundamental_with_octave,
|
389 |
+
"duration": final_duration,
|
390 |
+
"cumulative_duration": cumulative
|
391 |
+
})
|
392 |
+
|
393 |
+
return json.dumps(result)
|
394 |
|
|
|
|
|
|
|
395 |
def get_style_based_on_level(level: str) -> str:
|
396 |
styles = {
|
397 |
"Beginner": ["simple", "legato", "stepwise"],
|
|
|
409 |
return random.choice(techniques.get(level, ["with slurs"]))
|
410 |
|
411 |
# -----------------------------------------------------------------------------
|
412 |
+
# 9. Mistral API: query, fallback on errors - UPDATED DURATION SYSTEM
|
413 |
# -----------------------------------------------------------------------------
|
414 |
def query_mistral(prompt: str, instrument: str, level: str, key: str,
|
415 |
time_sig: str, measures: int) -> str:
|
|
|
417 |
"Authorization": f"Bearer {MISTRAL_API_KEY}",
|
418 |
"Content-Type": "application/json",
|
419 |
}
|
420 |
+
numerator, denominator = map(int, time_sig.split('/'))
|
|
|
421 |
|
422 |
+
# UPDATED: Calculate total required 8th notes
|
423 |
+
units_per_measure = numerator * (8 // denominator)
|
424 |
+
required_total = measures * units_per_measure
|
425 |
+
|
426 |
+
# UPDATED: Duration explanation in prompt
|
427 |
+
duration_constraint = (
|
428 |
+
f"Sum of all durations MUST BE EXACTLY {required_total} units (8th notes). "
|
429 |
+
f"Each integer duration represents an 8th note (1=8th, 2=quarter, 4=half, 8=whole). "
|
430 |
+
f"If it doesn't match, the exercise is invalid."
|
431 |
)
|
432 |
+
system_prompt = (
|
433 |
+
f"You are an expert music teacher specializing in {instrument.lower()}. "
|
434 |
+
"Create customized exercises using INTEGER durations representing 8th notes."
|
435 |
+
)
|
436 |
+
|
437 |
if prompt.strip():
|
438 |
+
user_prompt = (
|
439 |
+
f"{prompt} {duration_constraint} Output ONLY a JSON array of objects with "
|
440 |
+
"the following structure: [{{'note': string, 'duration': integer, 'cumulative_duration': integer}}]"
|
441 |
+
)
|
442 |
else:
|
|
|
443 |
style = get_style_based_on_level(level)
|
444 |
+
technique = get_technique_based_on_level(level)
|
445 |
+
# Extract fundamental note from key signature
|
446 |
+
fundamental_note = key.split()[0] # Gets 'C' from 'C Major' or 'A' from 'A Minor'
|
447 |
+
is_major = "Major" in key
|
448 |
+
|
449 |
+
# Create additional musical constraints
|
450 |
+
key_constraints = (
|
451 |
+
f"The exercise MUST end on the fundamental note of the key ({fundamental_note}). "
|
452 |
+
f"{'' if not is_major else 'For this major key, avoid using the minor 7th degree.'}"
|
453 |
+
)
|
454 |
+
|
455 |
user_prompt = (
|
456 |
+
f"Create a {style} {instrument.lower()} exercise in {key} with {time_sig} time signature "
|
457 |
+
f"{technique} for a {level.lower()} player. {duration_constraint} {key_constraints} "
|
458 |
+
"Output ONLY a JSON array of objects with the following structure: "
|
459 |
+
"[{{'note': string, 'duration': integer, 'cumulative_duration': integer}}] "
|
460 |
+
"Use standard note names (e.g., \"Bb4\", \"F#5\"). Monophonic only. "
|
461 |
+
"Durations: 1=8th, 2=quarter, 4=half, 8=whole. "
|
462 |
+
"Sum must be exactly as specified. ONLY output the JSON array. No prose."
|
463 |
)
|
464 |
+
|
465 |
payload = {
|
466 |
"model": "mistral-medium",
|
467 |
"messages": [
|
468 |
+
{"role": "system", "content": system_prompt},
|
469 |
+
{"role": "user", "content": user_prompt},
|
470 |
],
|
471 |
+
"temperature": 0.7 if level == "Advanced" else 0.5,
|
472 |
+
"max_tokens": 1000,
|
473 |
+
"top_p": 0.95,
|
474 |
+
"frequency_penalty": 0.2,
|
475 |
+
"presence_penalty": 0.2,
|
476 |
}
|
477 |
+
|
478 |
try:
|
479 |
response = requests.post(MISTRAL_API_URL, headers=headers, json=payload)
|
480 |
response.raise_for_status()
|
481 |
+
content = response.json()["choices"][0]["message"]["content"]
|
482 |
+
return content.replace("```json","").replace("```","").strip()
|
483 |
except Exception as e:
|
484 |
print(f"Error querying Mistral API: {e}")
|
485 |
return get_fallback_exercise(instrument, level, key, time_sig, measures)
|
486 |
|
487 |
# -----------------------------------------------------------------------------
|
488 |
+
# 10. Robust JSON parsing for LLM outputs - ENHANCED PARSING
|
489 |
# -----------------------------------------------------------------------------
|
490 |
def safe_parse_json(text: str) -> Optional[list]:
|
491 |
try:
|
492 |
+
text = text.strip().replace("'", '"')
|
493 |
+
|
494 |
+
# Find JSON array in the text
|
495 |
+
start_idx = text.find('[')
|
496 |
+
end_idx = text.rfind(']')
|
497 |
+
if start_idx == -1 or end_idx == -1:
|
498 |
+
return None
|
499 |
+
|
500 |
+
json_str = text[start_idx:end_idx+1]
|
501 |
+
|
502 |
+
# Fix common JSON issues
|
503 |
+
json_str = re.sub(r',\s*([}\]])', r'\1', json_str) # Trailing commas
|
504 |
+
json_str = re.sub(r'{\s*(\w+)\s*:', r'{"\1":', json_str) # Unquoted keys
|
505 |
+
json_str = re.sub(r':\s*([a-zA-Z_][a-zA-Z0-9_]*)(\s*[,}])', r':"\1"\2', json_str) # Unquoted strings
|
506 |
+
|
507 |
+
parsed = json.loads(json_str)
|
508 |
+
|
509 |
+
# Normalize keys to 'note' and 'duration'
|
510 |
+
normalized = []
|
511 |
+
for item in parsed:
|
512 |
+
if isinstance(item, dict):
|
513 |
+
# Find note value - accept multiple keys
|
514 |
+
note_val = None
|
515 |
+
for key in ['note', 'pitch', 'nota', 'ton']:
|
516 |
+
if key in item:
|
517 |
+
note_val = str(item[key])
|
518 |
+
break
|
519 |
+
|
520 |
+
# Find duration value
|
521 |
+
dur_val = None
|
522 |
+
for key in ['duration', 'dur', 'length', 'value']:
|
523 |
+
if key in item:
|
524 |
+
try:
|
525 |
+
dur_val = int(item[key])
|
526 |
+
except (TypeError, ValueError):
|
527 |
+
pass
|
528 |
+
|
529 |
+
if note_val is not None and dur_val is not None:
|
530 |
+
normalized.append({"note": note_val, "duration": dur_val})
|
531 |
+
|
532 |
+
return normalized if normalized else None
|
533 |
+
|
534 |
except Exception as e:
|
535 |
print(f"JSON parsing error: {e}\nRaw text: {text}")
|
536 |
return None
|
537 |
|
538 |
# -----------------------------------------------------------------------------
|
539 |
+
# 11. Main orchestration: talk to API, *scale durations*, build MIDI, UI values - UPDATED
|
540 |
# -----------------------------------------------------------------------------
|
541 |
+
def generate_exercise(instrument: str, level: str, key: str, tempo: int, time_signature: str,
|
542 |
+
measures: int, custom_prompt: str, mode: str) -> Tuple[str, Optional[str], str, MidiFile, str, str, int]:
|
|
|
543 |
try:
|
544 |
prompt_to_use = custom_prompt if mode == "Exercise Prompt" else ""
|
545 |
+
output = query_mistral(prompt_to_use, instrument, level, key, time_signature, measures)
|
546 |
+
parsed = safe_parse_json(output)
|
|
|
547 |
if not parsed:
|
548 |
+
print("Primary parsing failed, using fallback")
|
549 |
+
fallback_str = get_fallback_exercise(instrument, level, key, time_signature, measures)
|
550 |
+
parsed = safe_parse_json(fallback_str)
|
551 |
+
if not parsed:
|
552 |
+
print("Fallback parsing failed, using ultimate fallback")
|
553 |
+
# Ultimate fallback: simple scale based on selected key
|
554 |
+
key_notes = {
|
555 |
+
"C Major": ["C4", "D4", "E4", "F4", "G4", "A4", "B4", "C5"],
|
556 |
+
"G Major": ["G3", "A3", "B3", "C4", "D4", "E4", "F#4", "G4"],
|
557 |
+
"D Major": ["D4", "E4", "F#4", "G4", "A4", "B4", "C#5", "D5"],
|
558 |
+
"F Major": ["F3", "G3", "A3", "Bb3", "C4", "D4", "E4", "F4"],
|
559 |
+
"Bb Major": ["Bb3", "C4", "D4", "Eb4", "F4", "G4", "A4", "Bb4"],
|
560 |
+
"A Minor": ["A3", "B3", "C4", "D4", "E4", "F4", "G4", "A4"],
|
561 |
+
"E Minor": ["E3", "F#3", "G3", "A3", "B3", "C4", "D4", "E4"],
|
562 |
+
}
|
563 |
+
notes = key_notes.get(key, key_notes["C Major"])
|
564 |
+
numerator, denominator = map(int, time_signature.split('/'))
|
565 |
+
units_per_measure = numerator * (8 // denominator)
|
566 |
+
target_units = measures * units_per_measure
|
567 |
+
note_duration = max(1, target_units // len(notes))
|
568 |
+
parsed = [{"note": n, "duration": note_duration} for n in notes]
|
569 |
+
# Adjust last note to match total duration
|
570 |
+
total = sum(item["duration"] for item in parsed)
|
571 |
+
if total < target_units:
|
572 |
+
parsed[-1]["duration"] += target_units - total
|
573 |
+
elif total > target_units:
|
574 |
+
parsed[-1]["duration"] -= total - target_units
|
575 |
+
|
576 |
+
# Calculate total required 8th notes (UPDATED)
|
577 |
+
numerator, denominator = map(int, time_signature.split('/'))
|
578 |
+
units_per_measure = numerator * (8 // denominator)
|
579 |
+
total_units = measures * units_per_measure
|
580 |
+
|
581 |
+
# Convert to old format for scaling
|
582 |
+
old_format = []
|
583 |
+
for item in parsed:
|
584 |
+
if isinstance(item, dict):
|
585 |
+
old_format.append([item["note"], item["duration"]])
|
586 |
+
else:
|
587 |
+
old_format.append(item)
|
588 |
+
|
589 |
+
# Strict scaling
|
590 |
+
parsed_scaled_old = scale_json_durations(old_format, total_units)
|
591 |
+
|
592 |
+
# Convert back to new format with cumulative durations
|
593 |
+
cumulative = 0
|
594 |
+
parsed_scaled = []
|
595 |
+
for note, dur in parsed_scaled_old:
|
596 |
+
cumulative += dur
|
597 |
+
parsed_scaled.append({
|
598 |
+
"note": note,
|
599 |
+
"duration": dur,
|
600 |
+
"cumulative_duration": cumulative
|
601 |
+
})
|
602 |
|
|
|
|
|
|
|
|
|
|
|
603 |
# Calculate total duration units
|
604 |
+
total_duration = cumulative
|
605 |
+
|
606 |
# Generate MIDI and audio
|
607 |
+
midi = json_to_midi(parsed_scaled, instrument, tempo, time_signature, measures, key)
|
608 |
mp3_path, real_duration = midi_to_mp3(midi, instrument)
|
609 |
+
output_json_str = json.dumps(parsed_scaled, indent=2)
|
610 |
+
return output_json_str, mp3_path, str(tempo), midi, f"{real_duration:.2f} seconds", time_signature, total_duration
|
|
|
611 |
except Exception as e:
|
612 |
return f"Error: {str(e)}", None, str(tempo), None, "0", time_signature, 0
|
613 |
|
614 |
# -----------------------------------------------------------------------------
|
615 |
+
# 12. Simple AI chat assistant (optional, shares LLM)
|
616 |
# -----------------------------------------------------------------------------
|
617 |
def handle_chat(message: str, history: List, instrument: str, level: str):
|
618 |
if not message.strip():
|
619 |
return "", history
|
620 |
+
messages = [{"role": "system", "content": f"You are a {instrument} teacher for {level} students."}]
|
|
|
|
|
|
|
621 |
for user_msg, assistant_msg in history:
|
622 |
messages.append({"role": "user", "content": user_msg})
|
623 |
messages.append({"role": "assistant", "content": assistant_msg})
|
|
|
624 |
messages.append({"role": "user", "content": message})
|
625 |
+
headers = {"Authorization": f"Bearer {MISTRAL_API_KEY}", "Content-Type": "application/json"}
|
626 |
+
payload = {"model": "mistral-medium", "messages": messages, "temperature": 0.7, "max_tokens": 500}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
627 |
try:
|
628 |
response = requests.post(MISTRAL_API_URL, headers=headers, json=payload)
|
629 |
response.raise_for_status()
|
|
|
635 |
return "", history
|
636 |
|
637 |
# -----------------------------------------------------------------------------
|
638 |
+
# 13. New features: Visualization, Metronome, and Exercise Library
|
639 |
+
# -----------------------------------------------------------------------------
|
640 |
+
|
641 |
+
# Visualization function to create a piano roll representation of the exercise
|
642 |
+
def create_visualization(json_data, time_sig):
|
643 |
+
try:
|
644 |
+
if not json_data or "Error" in json_data:
|
645 |
+
return None
|
646 |
+
|
647 |
+
parsed = json.loads(json_data)
|
648 |
+
if not isinstance(parsed, list) or len(parsed) == 0:
|
649 |
+
return None
|
650 |
+
|
651 |
+
# Extract notes and durations
|
652 |
+
notes = []
|
653 |
+
durations = []
|
654 |
+
for item in parsed:
|
655 |
+
if isinstance(item, dict) and "note" in item and "duration" in item:
|
656 |
+
note_name = item["note"]
|
657 |
+
if not is_rest(note_name):
|
658 |
+
try:
|
659 |
+
midi_note = note_name_to_midi(note_name)
|
660 |
+
notes.append(midi_note)
|
661 |
+
durations.append(item["duration"])
|
662 |
+
except ValueError:
|
663 |
+
notes.append(60) # Default to middle C if parsing fails
|
664 |
+
durations.append(item["duration"])
|
665 |
+
else:
|
666 |
+
notes.append(None) # Represent rest
|
667 |
+
durations.append(item["duration"])
|
668 |
+
|
669 |
+
# Create piano roll visualization
|
670 |
+
fig, ax = plt.subplots(figsize=(12, 6))
|
671 |
+
|
672 |
+
# Calculate time positions
|
673 |
+
time_positions = [0]
|
674 |
+
for dur in durations[:-1]:
|
675 |
+
time_positions.append(time_positions[-1] + dur)
|
676 |
+
|
677 |
+
# Plot notes as rectangles
|
678 |
+
for i, (note, dur, pos) in enumerate(zip(notes, durations, time_positions)):
|
679 |
+
if note is not None: # Skip rests
|
680 |
+
rect = plt.Rectangle((pos, note-0.4), dur, 0.8, color='blue', alpha=0.7)
|
681 |
+
ax.add_patch(rect)
|
682 |
+
# Add note name
|
683 |
+
ax.text(pos + dur/2, note+0.5, midi_to_note_name(note),
|
684 |
+
ha='center', va='bottom', fontsize=8)
|
685 |
+
|
686 |
+
# Add measure lines
|
687 |
+
numerator, denominator = map(int, time_sig.split('/'))
|
688 |
+
units_per_measure = numerator * (8 // denominator)
|
689 |
+
max_time = time_positions[-1] + durations[-1]
|
690 |
+
for measure in range(1, int(max_time / units_per_measure) + 1):
|
691 |
+
measure_pos = measure * units_per_measure
|
692 |
+
if measure_pos <= max_time:
|
693 |
+
ax.axvline(x=measure_pos, color='gray', linestyle='--', alpha=0.5)
|
694 |
+
|
695 |
+
# Set axis limits and labels
|
696 |
+
ax.set_ylim(min(notes) - 5 if None not in notes else 55,
|
697 |
+
max(notes) + 5 if None not in notes else 75)
|
698 |
+
ax.set_xlim(0, max_time)
|
699 |
+
ax.set_ylabel('MIDI Note')
|
700 |
+
ax.set_xlabel('Time (8th note units)')
|
701 |
+
ax.set_title('Exercise Visualization')
|
702 |
+
|
703 |
+
# Add piano keyboard on y-axis
|
704 |
+
ax.set_yticks([60, 62, 64, 65, 67, 69, 71, 72]) # C4 to C5
|
705 |
+
ax.set_yticklabels(['C4', 'D4', 'E4', 'F4', 'G4', 'A4', 'B4', 'C5'])
|
706 |
+
ax.grid(True, axis='y', alpha=0.3)
|
707 |
+
|
708 |
+
# Save figure to temporary file
|
709 |
+
temp_img_path = os.path.join('static', f'visualization_{uuid.uuid4().hex}.png')
|
710 |
+
plt.tight_layout()
|
711 |
+
plt.savefig(temp_img_path)
|
712 |
+
plt.close()
|
713 |
+
|
714 |
+
return temp_img_path
|
715 |
+
except Exception as e:
|
716 |
+
print(f"Error creating visualization: {e}")
|
717 |
+
return None
|
718 |
+
|
719 |
+
# Metronome function
|
720 |
+
def create_metronome_audio(tempo, time_sig, measures):
|
721 |
+
try:
|
722 |
+
numerator, denominator = map(int, time_sig.split('/'))
|
723 |
+
# Create a MIDI file with metronome clicks
|
724 |
+
mid = MidiFile(ticks_per_beat=TICKS_PER_BEAT)
|
725 |
+
track = MidiTrack()
|
726 |
+
mid.tracks.append(track)
|
727 |
+
|
728 |
+
# Add time signature and tempo
|
729 |
+
track.append(MetaMessage('time_signature', numerator=numerator,
|
730 |
+
denominator=denominator, time=0))
|
731 |
+
track.append(MetaMessage('set_tempo', tempo=mido.bpm2tempo(int(tempo)), time=0))
|
732 |
+
|
733 |
+
# Calculate total beats
|
734 |
+
beats_per_measure = numerator
|
735 |
+
total_beats = beats_per_measure * measures
|
736 |
+
|
737 |
+
# Add metronome clicks (strong beat = note 77, weak beat = note 76)
|
738 |
+
for beat in range(total_beats):
|
739 |
+
# Strong beat on first beat of measure, weak beat otherwise
|
740 |
+
note_num = 77 if beat % beats_per_measure == 0 else 76
|
741 |
+
velocity = 100 if beat % beats_per_measure == 0 else 80
|
742 |
+
|
743 |
+
# Add note on (with time=0 for first beat)
|
744 |
+
if beat == 0:
|
745 |
+
track.append(Message('note_on', note=note_num, velocity=velocity, time=0))
|
746 |
+
else:
|
747 |
+
# Each beat is a quarter note (TICKS_PER_BEAT)
|
748 |
+
track.append(Message('note_on', note=note_num, velocity=velocity, time=TICKS_PER_BEAT))
|
749 |
+
|
750 |
+
# Short duration for click
|
751 |
+
track.append(Message('note_off', note=note_num, velocity=0, time=10))
|
752 |
+
|
753 |
+
# Save and convert to audio
|
754 |
+
with tempfile.NamedTemporaryFile(delete=False, suffix=".mid") as mid_file:
|
755 |
+
mid.save(mid_file.name)
|
756 |
+
wav_path = mid_file.name.replace(".mid", ".wav")
|
757 |
+
mp3_path = mid_file.name.replace(".mid", ".mp3")
|
758 |
+
|
759 |
+
# Use piano soundfont for metronome
|
760 |
+
sf2_path = get_soundfont("Piano")
|
761 |
+
try:
|
762 |
+
sp.run([
|
763 |
+
'fluidsynth', '-ni', sf2_path, mid_file.name,
|
764 |
+
'-F', wav_path, '-r', '44100', '-g', '1.0'
|
765 |
+
], check=True, capture_output=True)
|
766 |
+
except Exception:
|
767 |
+
fs = FluidSynth(sf2_path, sample_rate=44100, gain=1.0)
|
768 |
+
fs.midi_to_audio(mid_file.name, wav_path)
|
769 |
+
|
770 |
+
# Convert to MP3
|
771 |
+
sound = AudioSegment.from_wav(wav_path)
|
772 |
+
sound.export(mp3_path, format="mp3")
|
773 |
+
|
774 |
+
# Move to static directory
|
775 |
+
static_mp3_path = os.path.join('static', f'metronome_{uuid.uuid4().hex}.mp3')
|
776 |
+
shutil.move(mp3_path, static_mp3_path)
|
777 |
+
|
778 |
+
# Clean up temporary files
|
779 |
+
for f in [mid_file.name, wav_path]:
|
780 |
+
try:
|
781 |
+
os.remove(f)
|
782 |
+
except FileNotFoundError:
|
783 |
+
pass
|
784 |
+
|
785 |
+
return static_mp3_path
|
786 |
+
except Exception as e:
|
787 |
+
print(f"Error creating metronome: {e}")
|
788 |
+
return None
|
789 |
+
|
790 |
+
# Function to save exercise to library
|
791 |
+
def save_exercise_to_library(json_data, instrument, level, key, time_sig, tempo, audio_path):
|
792 |
+
try:
|
793 |
+
if not json_data or "Error" in json_data or not audio_path:
|
794 |
+
return False, "Invalid exercise data"
|
795 |
+
|
796 |
+
# Create unique ID for exercise
|
797 |
+
exercise_id = uuid.uuid4().hex
|
798 |
+
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
799 |
+
|
800 |
+
# Copy audio file to permanent storage
|
801 |
+
if audio_path and os.path.exists(audio_path):
|
802 |
+
exercise_audio_path = os.path.join("temp_audio", f"exercise_{exercise_id}.mp3")
|
803 |
+
shutil.copy(audio_path, exercise_audio_path)
|
804 |
+
else:
|
805 |
+
exercise_audio_path = ""
|
806 |
+
|
807 |
+
# Create exercise metadata
|
808 |
+
exercise_data = {
|
809 |
+
"id": exercise_id,
|
810 |
+
"timestamp": timestamp,
|
811 |
+
"instrument": instrument,
|
812 |
+
"level": level,
|
813 |
+
"key": key,
|
814 |
+
"time_signature": time_sig,
|
815 |
+
"tempo": tempo,
|
816 |
+
"json_data": json_data,
|
817 |
+
"audio_path": exercise_audio_path
|
818 |
+
}
|
819 |
+
|
820 |
+
# Save to library file
|
821 |
+
library_file = os.path.join("saved_exercises", "library.json")
|
822 |
+
|
823 |
+
# Load existing library or create new one
|
824 |
+
if os.path.exists(library_file):
|
825 |
+
try:
|
826 |
+
with open(library_file, "r") as f:
|
827 |
+
library = json.load(f)
|
828 |
+
except json.JSONDecodeError:
|
829 |
+
library = {"exercises": []}
|
830 |
+
else:
|
831 |
+
library = {"exercises": []}
|
832 |
+
|
833 |
+
# Add new exercise
|
834 |
+
library["exercises"].append(exercise_data)
|
835 |
+
|
836 |
+
# Save updated library
|
837 |
+
with open(library_file, "w") as f:
|
838 |
+
json.dump(library, f, indent=2)
|
839 |
+
|
840 |
+
return True, f"Exercise saved to library with ID: {exercise_id}"
|
841 |
+
except Exception as e:
|
842 |
+
return False, f"Error saving exercise: {str(e)}"
|
843 |
+
|
844 |
+
# Function to load exercises from library
|
845 |
+
def load_exercises_from_library():
|
846 |
+
try:
|
847 |
+
library_file = os.path.join("saved_exercises", "library.json")
|
848 |
+
if not os.path.exists(library_file):
|
849 |
+
return []
|
850 |
+
|
851 |
+
with open(library_file, "r") as f:
|
852 |
+
library = json.load(f)
|
853 |
+
|
854 |
+
return library.get("exercises", [])
|
855 |
+
except Exception as e:
|
856 |
+
print(f"Error loading library: {e}")
|
857 |
+
return []
|
858 |
+
|
859 |
+
# Function to calculate difficulty rating
|
860 |
+
def calculate_difficulty_rating(json_data, level):
|
861 |
+
try:
|
862 |
+
if not json_data or "Error" in json_data:
|
863 |
+
return 0
|
864 |
+
|
865 |
+
parsed = json.loads(json_data)
|
866 |
+
if not isinstance(parsed, list) or len(parsed) == 0:
|
867 |
+
return 0
|
868 |
+
|
869 |
+
# Extract notes and durations
|
870 |
+
notes = []
|
871 |
+
durations = []
|
872 |
+
for item in parsed:
|
873 |
+
if isinstance(item, dict) and "note" in item and "duration" in item:
|
874 |
+
note_name = item["note"]
|
875 |
+
if not is_rest(note_name):
|
876 |
+
try:
|
877 |
+
midi_note = note_name_to_midi(note_name)
|
878 |
+
notes.append(midi_note)
|
879 |
+
durations.append(item["duration"])
|
880 |
+
except ValueError:
|
881 |
+
pass
|
882 |
+
|
883 |
+
if not notes:
|
884 |
+
return 0
|
885 |
+
|
886 |
+
# Calculate difficulty factors
|
887 |
+
# 1. Range (wider range = harder)
|
888 |
+
note_range = max(notes) - min(notes) if notes else 0
|
889 |
+
range_factor = min(note_range / 12, 1.0) # Normalize to octave
|
890 |
+
|
891 |
+
# 2. Rhythmic complexity (more varied durations = harder)
|
892 |
+
unique_durations = len(set(durations))
|
893 |
+
rhythm_factor = min(unique_durations / 4, 1.0) # Normalize
|
894 |
+
|
895 |
+
# 3. Interval jumps (larger jumps = harder)
|
896 |
+
jumps = [abs(notes[i] - notes[i-1]) for i in range(1, len(notes))]
|
897 |
+
avg_jump = sum(jumps) / len(jumps) if jumps else 0
|
898 |
+
jump_factor = min(avg_jump / 7, 1.0) # Normalize to perfect fifth
|
899 |
+
|
900 |
+
# 4. Speed factor (shorter durations = harder)
|
901 |
+
avg_duration = sum(durations) / len(durations) if durations else 0
|
902 |
+
speed_factor = min(2.0 / avg_duration if avg_duration > 0 else 1.0, 1.0) # Normalize
|
903 |
+
|
904 |
+
# Calculate base difficulty
|
905 |
+
base_difficulty = (range_factor * 0.25 +
|
906 |
+
rhythm_factor * 0.25 +
|
907 |
+
jump_factor * 0.25 +
|
908 |
+
speed_factor * 0.25)
|
909 |
+
|
910 |
+
# Apply level multiplier
|
911 |
+
level_multiplier = {
|
912 |
+
"Beginner": 0.7,
|
913 |
+
"Intermediate": 1.0,
|
914 |
+
"Advanced": 1.3
|
915 |
+
}.get(level, 1.0)
|
916 |
+
|
917 |
+
# Calculate final rating (1-10 scale)
|
918 |
+
rating = round(base_difficulty * level_multiplier * 10)
|
919 |
+
return max(1, min(rating, 10)) # Ensure between 1-10
|
920 |
+
except Exception as e:
|
921 |
+
print(f"Error calculating difficulty: {e}")
|
922 |
+
return 0
|
923 |
+
|
924 |
+
# -----------------------------------------------------------------------------
|
925 |
+
# 14. Gradio user interface definition (for humans!) - ENHANCED GUI
|
926 |
# -----------------------------------------------------------------------------
|
927 |
def create_ui() -> gr.Blocks:
|
928 |
with gr.Blocks(title="Adaptive Music Exercise Generator", theme="soft") as demo:
|
929 |
+
gr.Markdown("# πΌ Adaptive Music Exercise Generator")
|
930 |
current_midi = gr.State(None)
|
931 |
current_exercise = gr.State("")
|
932 |
+
current_audio_path = gr.State(None)
|
933 |
|
934 |
+
mode = gr.Radio(["Exercise Parameters","Exercise Prompt"], value="Exercise Parameters", label="Generation Mode")
|
|
|
|
|
935 |
with gr.Row():
|
936 |
with gr.Column(scale=1):
|
937 |
with gr.Group(visible=True) as params_group:
|
938 |
gr.Markdown("### Exercise Parameters")
|
939 |
+
instrument = gr.Dropdown([
|
940 |
+
"Trumpet", "Piano", "Violin", "Clarinet", "Flute",
|
941 |
+
], value="Trumpet", label="Instrument")
|
942 |
+
level = gr.Radio([
|
943 |
+
"Beginner", "Intermediate", "Advanced",
|
944 |
+
], value="Intermediate", label="Difficulty Level")
|
945 |
+
key = gr.Dropdown([
|
946 |
+
"C Major", "G Major", "D Major", "F Major", "Bb Major", "A Minor", "E Minor",
|
947 |
+
], value="C Major", label="Key Signature")
|
948 |
+
time_signature = gr.Dropdown(["3/4", "4/4"], value="4/4", label="Time Signature")
|
949 |
+
measures = gr.Radio([4, 8, 12, 16], value=4, label="Length (measures)")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
950 |
with gr.Group(visible=False) as prompt_group:
|
951 |
gr.Markdown("### Exercise Prompt")
|
952 |
+
custom_prompt = gr.Textbox("", label="Enter your custom prompt", lines=3)
|
953 |
+
measures_prompt = gr.Radio([4, 8, 12, 16], value=4, label="Length (measures)")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
954 |
generate_btn = gr.Button("Generate Exercise", variant="primary")
|
|
|
955 |
with gr.Column(scale=2):
|
956 |
with gr.Tabs():
|
957 |
with gr.TabItem("Exercise Player"):
|
958 |
+
audio_output = gr.Audio(label="Generated Exercise", autoplay=True, type="filepath")
|
959 |
+
with gr.Row():
|
960 |
+
bpm_display = gr.Textbox(label="Tempo (BPM)")
|
961 |
+
time_sig_display = gr.Textbox(label="Time Signature")
|
962 |
+
duration_display = gr.Textbox(label="Audio Duration", interactive=False)
|
963 |
+
with gr.Row():
|
964 |
+
difficulty_rating = gr.Number(label="Difficulty Rating (1-10)", interactive=False, precision=1)
|
965 |
+
save_btn = gr.Button("Save to Library", variant="secondary")
|
966 |
+
|
967 |
+
# Metronome section
|
968 |
+
gr.Markdown("### Metronome")
|
969 |
+
with gr.Row():
|
970 |
+
metronome_tempo = gr.Slider(minimum=40, maximum=200, value=60, step=1, label="Metronome Tempo")
|
971 |
+
metronome_btn = gr.Button("Generate Metronome", variant="secondary")
|
972 |
+
metronome_audio = gr.Audio(label="Metronome", type="filepath")
|
973 |
+
|
974 |
with gr.TabItem("Exercise Data"):
|
975 |
+
json_output = gr.Code(label="JSON Representation", language="json")
|
|
|
|
|
|
|
|
|
|
|
976 |
duration_sum = gr.Number(
|
977 |
+
label="Total Duration Units (8th notes)",
|
978 |
interactive=False,
|
979 |
precision=0
|
980 |
)
|
981 |
+
|
982 |
+
with gr.TabItem("Visualization"):
|
983 |
+
visualization_output = gr.Image(label="Exercise Visualization", type="filepath")
|
984 |
+
visualize_btn = gr.Button("Generate Visualization", variant="secondary")
|
985 |
+
|
986 |
with gr.TabItem("MIDI Export"):
|
987 |
midi_output = gr.File(label="MIDI File")
|
988 |
download_midi = gr.Button("Generate MIDI File")
|
989 |
+
|
990 |
+
with gr.TabItem("Exercise Library"):
|
991 |
+
refresh_library_btn = gr.Button("Refresh Library", variant="secondary")
|
992 |
+
library_dropdown = gr.Dropdown([], label="Saved Exercises", interactive=True)
|
993 |
+
load_exercise_btn = gr.Button("Load Selected Exercise", variant="secondary")
|
994 |
+
|
995 |
with gr.TabItem("AI Chat"):
|
996 |
+
chat_history = gr.Chatbot(label="Practice Assistant", height=400)
|
997 |
+
chat_message = gr.Textbox(label="Ask the AI anything about your practice")
|
|
|
|
|
|
|
998 |
send_chat_btn = gr.Button("Send")
|
999 |
+
# Toggle UI groups
|
|
|
1000 |
mode.change(
|
1001 |
fn=lambda m: {
|
1002 |
params_group: gr.update(visible=(m == "Exercise Parameters")),
|
1003 |
prompt_group: gr.update(visible=(m == "Exercise Prompt")),
|
1004 |
},
|
1005 |
+
inputs=[mode], outputs=[params_group, prompt_group]
|
|
|
1006 |
)
|
|
|
|
|
1007 |
def generate_caller(mode_val, instrument_val, level_val, key_val,
|
1008 |
+
time_sig_val, measures_val, prompt_val, measures_prompt_val):
|
1009 |
real_measures = measures_prompt_val if mode_val == "Exercise Prompt" else measures_val
|
1010 |
+
fixed_tempo = 60
|
1011 |
+
json_data, mp3_path, tempo, midi, duration, time_sig, total_duration = generate_exercise(
|
1012 |
+
instrument_val, level_val, key_val, fixed_tempo, time_sig_val,
|
1013 |
+
real_measures, prompt_val, mode_val
|
1014 |
)
|
1015 |
+
|
1016 |
+
# Calculate difficulty rating
|
1017 |
+
rating = calculate_difficulty_rating(json_data, level_val)
|
1018 |
+
|
1019 |
+
# Generate visualization
|
1020 |
+
viz_path = create_visualization(json_data, time_sig_val)
|
1021 |
+
|
1022 |
+
return json_data, mp3_path, tempo, midi, duration, time_sig, total_duration, rating, viz_path, mp3_path
|
1023 |
+
|
1024 |
generate_btn.click(
|
1025 |
fn=generate_caller,
|
1026 |
+
inputs=[mode, instrument, level, key, time_signature, measures, custom_prompt, measures_prompt],
|
1027 |
+
outputs=[json_output, audio_output, bpm_display, current_midi, duration_display,
|
1028 |
+
time_sig_display, duration_sum, difficulty_rating, visualization_output, current_audio_path]
|
|
|
1029 |
)
|
1030 |
|
1031 |
+
# Visualization button
|
1032 |
+
visualize_btn.click(
|
1033 |
+
fn=create_visualization,
|
1034 |
+
inputs=[json_output, time_signature],
|
1035 |
+
outputs=[visualization_output]
|
1036 |
+
)
|
1037 |
+
|
1038 |
+
# Metronome generation
|
1039 |
+
def generate_metronome(tempo, time_sig, measures_val):
|
1040 |
+
return create_metronome_audio(tempo, time_sig, measures_val)
|
1041 |
+
|
1042 |
+
metronome_btn.click(
|
1043 |
+
fn=generate_metronome,
|
1044 |
+
inputs=[metronome_tempo, time_signature, measures],
|
1045 |
+
outputs=[metronome_audio]
|
1046 |
+
)
|
1047 |
+
|
1048 |
+
# Save to library function
|
1049 |
+
def save_to_library(json_data, instrument_val, level_val, key_val, time_sig_val, tempo_val, audio_path):
|
1050 |
+
save_exercise_to_library(
|
1051 |
+
json_data, instrument_val, level_val, key_val, time_sig_val, tempo_val, audio_path
|
1052 |
+
)
|
1053 |
+
return
|
1054 |
+
|
1055 |
+
save_btn.click(
|
1056 |
+
fn=save_to_library,
|
1057 |
+
inputs=[json_output, instrument, level, key, time_signature, bpm_display, current_audio_path],
|
1058 |
+
outputs=[]
|
1059 |
+
)
|
1060 |
+
|
1061 |
+
# Library functions
|
1062 |
+
def refresh_library():
|
1063 |
+
exercises = load_exercises_from_library()
|
1064 |
+
options = [f"{ex['timestamp']} - {ex['instrument']} ({ex['level']}) - {ex['key']} {ex['time_signature']}"
|
1065 |
+
for ex in exercises]
|
1066 |
+
return gr.Dropdown.update(choices=options, value=options[0] if options else None)
|
1067 |
+
|
1068 |
+
refresh_library_btn.click(
|
1069 |
+
fn=refresh_library,
|
1070 |
+
inputs=[],
|
1071 |
+
outputs=[library_dropdown]
|
1072 |
+
)
|
1073 |
+
|
1074 |
+
def load_exercise_from_library(selected_exercise):
|
1075 |
+
if not selected_exercise:
|
1076 |
+
return None, None, None, None, None, None, 0, None
|
1077 |
|
1078 |
+
exercises = load_exercises_from_library()
|
1079 |
+
for i, ex in enumerate(exercises):
|
1080 |
+
option = f"{ex['timestamp']} - {ex['instrument']} ({ex['level']}) - {ex['key']} {ex['time_signature']}"
|
1081 |
+
if option == selected_exercise:
|
1082 |
+
# Load the exercise data
|
1083 |
+
json_data = ex['json_data']
|
1084 |
+
audio_path = ex['audio_path']
|
1085 |
+
tempo = ex['tempo']
|
1086 |
+
time_sig = ex['time_signature']
|
1087 |
+
|
1088 |
+
# Calculate duration sum
|
1089 |
+
try:
|
1090 |
+
parsed = json.loads(json_data)
|
1091 |
+
total_duration = sum(item['duration'] for item in parsed if isinstance(item, dict))
|
1092 |
+
except:
|
1093 |
+
total_duration = 0
|
1094 |
+
|
1095 |
+
# Calculate difficulty rating
|
1096 |
+
rating = calculate_difficulty_rating(json_data, ex['level'])
|
1097 |
+
|
1098 |
+
# Generate visualization
|
1099 |
+
viz_path = create_visualization(json_data, time_sig)
|
1100 |
+
|
1101 |
+
# Calculate audio duration
|
1102 |
+
try:
|
1103 |
+
audio = AudioSegment.from_file(audio_path)
|
1104 |
+
duration = f"{audio.duration_seconds:.2f} seconds"
|
1105 |
+
except:
|
1106 |
+
duration = "Unknown"
|
1107 |
+
|
1108 |
+
return json_data, audio_path, tempo, duration, time_sig, total_duration, rating, viz_path
|
1109 |
|
1110 |
+
return None, None, None, None, None, None, 0, None
|
|
|
1111 |
|
1112 |
+
load_exercise_btn.click(
|
1113 |
+
fn=load_exercise_from_library,
|
1114 |
+
inputs=[library_dropdown],
|
1115 |
+
outputs=[json_output, audio_output, bpm_display, duration_display,
|
1116 |
+
time_sig_display, duration_sum, difficulty_rating, visualization_output]
|
1117 |
+
)
|
1118 |
|
1119 |
+
def save_midi(json_data, instr, time_sig, key_sig="C Major"):
|
1120 |
+
try:
|
1121 |
+
if not json_data or "Error" in json_data:
|
1122 |
+
return None
|
1123 |
+
|
1124 |
+
parsed = json.loads(json_data)
|
1125 |
+
|
1126 |
+
# Validate JSON structure
|
1127 |
+
if not isinstance(parsed, list):
|
1128 |
+
return None
|
1129 |
+
|
1130 |
+
old_format = []
|
1131 |
+
for item in parsed:
|
1132 |
+
if isinstance(item, dict) and "note" in item and "duration" in item:
|
1133 |
+
old_format.append([item["note"], item["duration"]])
|
1134 |
+
|
1135 |
+
if not old_format:
|
1136 |
+
return None
|
1137 |
+
|
1138 |
+
# Calculate total units
|
1139 |
+
total_units = sum(d[1] for d in old_format)
|
1140 |
+
numerator, denominator = map(int, time_sig.split('/'))
|
1141 |
+
units_per_measure = numerator * (8 // denominator)
|
1142 |
+
measures_est = max(1, round(total_units / units_per_measure))
|
1143 |
+
|
1144 |
+
# Generate MIDI
|
1145 |
+
cumulative = 0
|
1146 |
+
scaled_new = []
|
1147 |
+
for note, dur in old_format:
|
1148 |
+
cumulative += dur
|
1149 |
+
scaled_new.append({
|
1150 |
+
"note": note,
|
1151 |
+
"duration": dur,
|
1152 |
+
"cumulative_duration": cumulative
|
1153 |
+
})
|
1154 |
+
|
1155 |
+
midi_obj = json_to_midi(scaled_new, instr, 60, time_sig, measures_est, key=key_sig)
|
1156 |
+
midi_path = os.path.join("static", "exercise.mid")
|
1157 |
+
midi_obj.save(midi_path)
|
1158 |
+
return midi_path
|
1159 |
+
except Exception as e:
|
1160 |
+
print(f"Error saving MIDI: {e}")
|
1161 |
+
return None
|
1162 |
+
|
1163 |
download_midi.click(
|
1164 |
fn=save_midi,
|
1165 |
+
inputs=[json_output, instrument, time_signature, key],
|
1166 |
+
outputs=[midi_output],
|
1167 |
)
|
|
|
|
|
1168 |
send_chat_btn.click(
|
1169 |
fn=handle_chat,
|
1170 |
inputs=[chat_message, chat_history, instrument, level],
|
1171 |
+
outputs=[chat_message, chat_history],
|
1172 |
)
|
|
|
1173 |
return demo
|
1174 |
|
1175 |
# -----------------------------------------------------------------------------
|
1176 |
+
# 14. Entry point
|
1177 |
# -----------------------------------------------------------------------------
|
1178 |
if __name__ == "__main__":
|
1179 |
demo = create_ui()
|