import logging
import math
import os

# protobuf must use the pure-Python implementation; set this before note_seq
# (which imports protobuf) is loaded so it actually takes effect.
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"

from typing import Tuple, List, Optional

import gradio as gr
import note_seq
import torch
from matplotlib.figure import Figure
from note_seq.constants import STANDARD_PPQ
from note_seq.protobuf.music_pb2 import NoteSequence
from numpy import ndarray
from transformers import AutoTokenizer, AutoModelForCausalLM

logging.basicConfig(level=logging.INFO)

SAMPLE_RATE = 44100

# General MIDI program names, indexed 0-127, grouped by GM instrument family.
GM_INSTRUMENTS = [
    # Piano
    "Acoustic Grand Piano", "Bright Acoustic Piano", "Electric Grand Piano",
    "Honky-tonk Piano", "Electric Piano 1", "Electric Piano 2",
    "Harpsichord", "Clavi",
    # Chromatic percussion
    "Celesta", "Glockenspiel", "Music Box", "Vibraphone", "Marimba",
    "Xylophone", "Tubular Bells", "Dulcimer",
    # Organ
    "Drawbar Organ", "Percussive Organ", "Rock Organ", "Church Organ",
    "Reed Organ", "Accordion", "Harmonica", "Tango Accordion",
    # Guitar
    "Acoustic Guitar (nylon)", "Acoustic Guitar (steel)",
    "Electric Guitar (jazz)", "Electric Guitar (clean)",
    "Electric Guitar (muted)", "Overdriven Guitar", "Distortion Guitar",
    "Guitar Harmonics",
    # Bass
    "Acoustic Bass", "Electric Bass (finger)", "Electric Bass (pick)",
    "Fretless Bass", "Slap Bass 1", "Slap Bass 2", "Synth Bass 1",
    "Synth Bass 2",
    # Strings
    "Violin", "Viola", "Cello", "Contrabass", "Tremolo Strings",
    "Pizzicato Strings", "Orchestral Harp", "Timpani",
    # Ensemble
    "String Ensemble 1", "String Ensemble 2", "Synth Strings 1",
    "Synth Strings 2", "Choir Aahs", "Voice Oohs", "Synth Choir",
    "Orchestra Hit",
    # Brass
    "Trumpet", "Trombone", "Tuba", "Muted Trumpet", "French Horn",
    "Brass Section", "Synth Brass 1", "Synth Brass 2",
    # Reed
    "Soprano Sax", "Alto Sax", "Tenor Sax", "Baritone Sax", "Oboe",
    "English Horn", "Bassoon", "Clarinet",
    # Pipe
    "Piccolo", "Flute", "Recorder", "Pan Flute", "Blown Bottle",
    "Shakuhachi", "Whistle", "Ocarina",
    # Synth lead
    "Lead 1 (square)", "Lead 2 (sawtooth)", "Lead 3 (calliope)",
    "Lead 4 (chiff)", "Lead 5 (charang)", "Lead 6 (voice)",
    "Lead 7 (fifths)", "Lead 8 (bass + lead)",
    # Synth pad
    "Pad 1 (new age)", "Pad 2 (warm)", "Pad 3 (polysynth)",
    "Pad 4 (choir)", "Pad 5 (bowed)", "Pad 6 (metallic)", "Pad 7 (halo)",
    "Pad 8 (sweep)",
    # Synth effects
    "FX 1 (rain)", "FX 2 (soundtrack)", "FX 3 (crystal)",
    "FX 4 (atmosphere)", "FX 5 (brightness)", "FX 6 (goblins)",
    "FX 7 (echoes)", "FX 8 (sci-fi)",
    # Ethnic
    "Sitar", "Banjo", "Shamisen", "Koto", "Kalimba", "Bagpipe", "Fiddle",
    "Shanai",
    # Percussive
    "Tinkle Bell", "Agogo", "Steel Drums", "Woodblock", "Taiko Drum",
    "Melodic Tom", "Synth Drum", "Reverse Cymbal",
    # Sound effects
    "Guitar Fret Noise", "Breath Noise", "Seashore", "Bird Tweet",
    "Telephone Ring", "Helicopter", "Applause", "Gunshot",
]

# Lazily-initialized module globals, populated by get_model_and_tokenizer().
tokenizer = None
model = None

def get_model_and_tokenizer() -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
    """Load the model and tokenizer once, caching them in module globals."""
    logging.info("get_model_and_tokenizer: Starting to load model and tokenizer...")
    global model, tokenizer
    if model is None or tokenizer is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logging.info(f"get_model_and_tokenizer: Using device: {device}")
        tokenizer = AutoTokenizer.from_pretrained("juancopi81/lmd_8bars_tokenizer")
        model = AutoModelForCausalLM.from_pretrained(
            "juancopi81/lmd-8bars-2048-epochs40_v4"
        )
        model = model.to(device)
        logging.info("get_model_and_tokenizer: Model and tokenizer loaded successfully.")
    else:
        logging.info("get_model_and_tokenizer: Model and tokenizer already loaded.")
    return model, tokenizer


def token_sequence_to_note_sequence(
    token_sequence: str,
    qpm: float = 120.0,
    use_program: bool = True,
    use_drums: bool = True,
    instrument_mapper: Optional[dict] = None,
    only_piano: bool = False,
) -> NoteSequence:
    """Convert a token string (PIECE_START ... TRACK_END) into a NoteSequence."""
    logging.info(
        f"token_sequence_to_note_sequence: Starting conversion. QPM: {qpm}, "
        f"use_program: {use_program}, use_drums: {use_drums}, only_piano: {only_piano}"
    )
    if isinstance(token_sequence, str):
        token_sequence = token_sequence.split()
    note_sequence = empty_note_sequence(qpm)

    # Durations derived from the tempo: a 16th note and a 4/4 bar, in seconds.
    note_length_16th = 0.25 * 60 / qpm
    bar_length = 4.0 * 60 / qpm

    # Parser state, initialized up front so a malformed token stream (e.g. a
    # NOTE_ON before any BAR_START) cannot raise NameError.
    current_program = 1
    current_is_drum = False
    current_instrument = 0
    current_bar_index = 0
    current_time = 0.0
    current_notes = {}
    track_count = 0

    for token in token_sequence:
        if token == "PIECE_START":
            pass
        elif token == "PIECE_END":
            break
        elif token == "TRACK_START":
            current_bar_index = 0
            track_count += 1
        elif token == "TRACK_END":
            pass
        elif token == "KEYS_START":
            pass
        elif token == "KEYS_END":
            pass
        elif token.startswith("KEY="):
            pass
        elif token.startswith("INST"):
            instrument = token.split("=")[-1]
            if instrument != "DRUMS" and use_program:
                if instrument_mapper is not None:
                    if instrument in instrument_mapper:
                        instrument = instrument_mapper[instrument]
                current_program = int(instrument)
                current_instrument = track_count
                current_is_drum = False
            if instrument == "DRUMS" and use_drums:
                current_instrument = 0
                current_program = 0
                current_is_drum = True
        elif token == "BAR_START":
            current_time = current_bar_index * bar_length
            current_notes = {}
        elif token == "BAR_END":
            current_bar_index += 1
        elif token.startswith("NOTE_ON"):
            pitch = int(token.split("=")[-1])
            note = note_sequence.notes.add()
            note.start_time = current_time
            # Default to a quarter-note duration; trimmed by a matching NOTE_OFF.
            note.end_time = current_time + 4 * note_length_16th
            note.pitch = pitch
            note.instrument = current_instrument
            note.program = current_program
            note.velocity = 80
            note.is_drum = current_is_drum
            current_notes[pitch] = note
        elif token.startswith("NOTE_OFF"):
            pitch = int(token.split("=")[-1])
            if pitch in current_notes:
                note = current_notes[pitch]
                note.end_time = current_time
        elif token.startswith("TIME_DELTA"):
            delta = float(token.split("=")[-1]) * note_length_16th
            current_time += delta
        elif token.startswith("DENSITY="):
            pass
        elif token == "[PAD]":
            pass
        else:
            pass

    # Re-number instruments so each (program, is_drum) pair gets its own track.
    instruments_drums = []
    for note in note_sequence.notes:
        pair = [note.program, note.is_drum]
        if pair not in instruments_drums:
            instruments_drums.append(pair)
        note.instrument = instruments_drums.index(pair)

    if only_piano:
        for note in note_sequence.notes:
            if not note.is_drum:
                note.instrument = 0
                note.program = 0

    logging.info("token_sequence_to_note_sequence: Conversion to note sequence complete.")
    return note_sequence
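

# A minimal, hand-checkable example of the converter above (a sketch, shown as
# a comment so it is not executed at import time). At the default 120 QPM a
# 16th note is 0.125 s, so TIME_DELTA=4 advances time by 0.5 s:
#
#   seq = token_sequence_to_note_sequence(
#       "PIECE_START TRACK_START INST=0 BAR_START "
#       "NOTE_ON=60 TIME_DELTA=4 NOTE_OFF=60 BAR_END TRACK_END"
#   )
#   # seq.notes[0]: pitch 60, start_time 0.0, end_time 0.5, program 0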


def empty_note_sequence(qpm: float = 120.0, total_time: float = 0.0) -> NoteSequence:
    """Return a NoteSequence containing only a tempo and the standard PPQ."""
    note_sequence = NoteSequence()
    note_sequence.tempos.add().qpm = qpm
    note_sequence.ticks_per_quarter = STANDARD_PPQ
    note_sequence.total_time = total_time
    return note_sequence


# Load the model and tokenizer once at import time so the first request does
# not pay the loading cost. (get_model_and_tokenizer picks the device itself.)
model, tokenizer = get_model_and_tokenizer()


def create_seed_string(genre: str = "OTHER", prompt: str = "") -> str:
    logging.info(f"create_seed_string: Creating seed string. Genre: {genre}, Prompt: '{prompt}'")
    if prompt:
        seed_string = f"PIECE_START PROMPT={prompt} GENRE={genre} TRACK_START"
    elif genre == "RANDOM":
        seed_string = "PIECE_START"
    else:
        seed_string = f"PIECE_START GENRE={genre} TRACK_START"
    logging.info(f"create_seed_string: Seed string created: '{seed_string}'")
    return seed_string
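
# For example, create_seed_string("ROCK") yields
# "PIECE_START GENRE=ROCK TRACK_START", a non-empty prompt yields
# "PIECE_START PROMPT=<prompt> GENRE=<genre> TRACK_START", and genre "RANDOM"
# (with no prompt) yields a bare "PIECE_START".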


def get_instruments(text_sequence: str) -> List[str]:
    """Extract human-readable instrument names from the INST= tokens."""
    instruments = []
    parts = text_sequence.split()
    for part in parts:
        if part.startswith("INST="):
            if part[5:] == "DRUMS":
                instruments.append("Drums")
            else:
                index = int(part[5:])
                # Guard against out-of-range program numbers from the model.
                if 0 <= index < len(GM_INSTRUMENTS):
                    instruments.append(GM_INSTRUMENTS[index])
                else:
                    instruments.append(f"Program {index}")
    return instruments
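
# For example:
#   get_instruments("PIECE_START TRACK_START INST=0 TRACK_END TRACK_START INST=DRUMS TRACK_END")
#   # -> ["Acoustic Grand Piano", "Drums"]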


def generate_new_instrument(seed: str, temp: float = 0.85, max_tokens: int = 512) -> str:
    logging.info(
        f"generate_new_instrument: Starting instrument generation. "
        f"Seed: '{seed}', Temperature: {temp}, Max Tokens: {max_tokens}"
    )
    seed_length = len(tokenizer.encode(seed))
    input_ids = tokenizer.encode(seed, return_tensors="pt").to(model.device)
    # Stop generation once the model closes the track.
    eos_token_id = tokenizer.encode("TRACK_END")[0]
    generated_ids = model.generate(
        input_ids,
        max_new_tokens=max_tokens,
        do_sample=True,
        temperature=temp,
        eos_token_id=eos_token_id,
    )
    generated_sequence = tokenizer.decode(generated_ids[0])
    new_generated_sequence = tokenizer.decode(generated_ids[0][seed_length:])
    logging.info(f"generate_new_instrument: Generated sequence: '{new_generated_sequence}'")
    if "NOTE_ON" in new_generated_sequence:
        logging.info("generate_new_instrument: New instrument generated successfully.")
        return generated_sequence
    else:
        logging.warning(
            "generate_new_instrument: No NOTE_ON token found in generated sequence "
            "after seed. Generation may be incomplete."
        )
        return ""


def get_outputs_from_string(
    generated_sequence: str, qpm: int = 120
) -> Tuple[ndarray, str, Figure, str, str]:
    logging.info(f"get_outputs_from_string: Starting output generation. QPM: {qpm}")
    instruments = get_instruments(generated_sequence)
    instruments_str = "\n".join(f"- {instrument}" for instrument in instruments)
    note_sequence = token_sequence_to_note_sequence(generated_sequence, qpm=qpm)

    if not note_sequence.notes:
        logging.warning("get_outputs_from_string: Note sequence is empty, skipping plot.")
        fig = None
    else:
        fig = note_seq.plot_sequence(note_sequence, show_figure=False)

    # Render audio with FluidSynth, then build a waveform video for Gradio.
    synth = note_seq.fluidsynth
    array_of_floats = synth(note_sequence, sample_rate=SAMPLE_RATE)
    int16_data = note_seq.audio_io.float_samples_to_int16(array_of_floats)
    num_tokens = str(len(generated_sequence.split()))
    audio = gr.make_waveform((SAMPLE_RATE, int16_data))
    note_seq.note_sequence_to_midi_file(note_sequence, "midi_output.mid")
    logging.info("get_outputs_from_string: Output generation complete.")
    return audio, "midi_output.mid", fig, instruments_str, num_tokens


def remove_last_instrument(
    text_sequence: str, qpm: int = 120
) -> Tuple[ndarray, str, Figure, str, str, str]:
    logging.info(f"remove_last_instrument: Removing last instrument. QPM: {qpm}")
    tracks = text_sequence.split("TRACK_START")
    modified_tracks = tracks[:-1]
    new_song = "TRACK_START".join(modified_tracks)
    if len(tracks) == 2:
        # Only one track existed: removing it leaves just the piece header,
        # so generate a fresh single track from that header.
        audio, midi_file, fig, instruments_str, new_song, num_tokens = generate_song(
            text_sequence=new_song, qpm=qpm, duration=1
        )
    elif len(tracks) == 1:
        # No TRACK_START at all: start over from an empty sequence.
        audio, midi_file, fig, instruments_str, new_song, num_tokens = generate_song(
            text_sequence="", qpm=qpm, duration=1
        )
    else:
        audio, midi_file, fig, instruments_str, num_tokens = get_outputs_from_string(
            new_song, qpm
        )
    logging.info("remove_last_instrument: Last instrument removed.")
    return audio, midi_file, fig, instruments_str, new_song, num_tokens
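
# Note: splitting on "TRACK_START" keeps the piece header in tracks[0], so
# dropping the last element removes only the final track's tokens.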


def regenerate_last_instrument(
    text_sequence: str, qpm: int = 120
) -> Tuple[ndarray, str, Figure, str, str, str]:
    logging.info(f"regenerate_last_instrument: Regenerating last instrument. QPM: {qpm}")
    last_inst_index = text_sequence.rfind("INST=")
    if last_inst_index == -1:
        # No instrument token found: generate a new song from scratch.
        audio, midi_file, fig, instruments_str, new_song, num_tokens = generate_song(
            text_sequence="", qpm=qpm, duration=1
        )
    else:
        # Keep everything up to and including the last INST= token, and let
        # the model re-fill the notes for that track.
        next_space_index = text_sequence.find(" ", last_inst_index)
        new_seed = text_sequence[:next_space_index]
        audio, midi_file, fig, instruments_str, new_song, num_tokens = generate_song(
            text_sequence=new_seed, qpm=qpm, duration=1
        )
    logging.info("regenerate_last_instrument: Last instrument regenerated.")
    return audio, midi_file, fig, instruments_str, new_song, num_tokens


def change_tempo(
    text_sequence: str, qpm: int
) -> Tuple[ndarray, str, Figure, str, str, str]:
    logging.info(f"change_tempo: Changing tempo to {qpm} QPM.")
    audio, midi_file, fig, instruments_str, num_tokens = get_outputs_from_string(
        text_sequence, qpm=qpm
    )
    logging.info(f"change_tempo: Tempo changed to {qpm} QPM.")
    return audio, midi_file, fig, instruments_str, text_sequence, num_tokens


def generate_song(
    genre: str = "OTHER",
    temp: float = 0.85,
    text_sequence: str = "",
    qpm: int = 120,
    prompt: str = "",
    duration: int = 30,
) -> Tuple[ndarray, str, Figure, str, str, str]:
    logging.info(
        f"generate_song: Starting song generation. Genre: {genre}, Temperature: {temp}, "
        f"QPM: {qpm}, Duration: {duration} seconds, Prompt: '{prompt}'"
    )
    if text_sequence == "":
        seed_string = create_seed_string(genre, prompt)
    else:
        seed_string = text_sequence

    # Each generated track spans 8 bars (about 16 s at the default 120 QPM),
    # so approximate the number of tracks needed from the requested duration.
    num_tracks = max(1, int(math.ceil(duration / 17)))

    generated_sequence = seed_string
    for _ in range(num_tracks):
        instrument_sequence = generate_new_instrument(seed=generated_sequence, temp=temp)
        if instrument_sequence:
            generated_sequence = instrument_sequence
        else:
            logging.warning(
                "generate_song: Instrument generation failed, stopping track generation early."
            )
            break

    audio, midi_file, fig, instruments_str, num_tokens = get_outputs_from_string(
        generated_sequence, qpm
    )
    logging.info("generate_song: Song generation complete.")
    return audio, midi_file, fig, instruments_str, generated_sequence, num_tokens
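
# Example call (a sketch; invokes the model, so it is slow without a GPU):
#   audio, midi_path, fig, instruments, song_tokens, n_tokens = generate_song(
#       genre="POP", temp=0.85, qpm=120, duration=30
#   )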


genres = ["ROCK", "POP", "OTHER", "R&B/SOUL", "JAZZ", "ELECTRONIC", "RANDOM"]

demo = gr.Blocks()


def run():
    with demo:
        gr.DuplicateButton(value="Duplicate Space for private use")
        with gr.Row():
            with gr.Column():
                prompt_text = gr.Textbox(
                    lines=2,
                    placeholder="Enter text prompt here...",
                    label="Text Prompt (Optional)",
                )
                duration_slider = gr.Slider(
                    minimum=1, maximum=1000, step=1, value=30, label="Duration (Seconds)"
                )
                temp = gr.Slider(
                    minimum=0, maximum=1, step=0.05, value=0.85, label="Temperature"
                )
                genre = gr.Dropdown(choices=genres, value="POP", label="Select Genre")
                with gr.Row():
                    btn_from_scratch = gr.Button("🧹 Start from scratch")
                    btn_continue = gr.Button("➡️ Continue Generation")
                    btn_remove_last = gr.Button("↩️ Remove last instrument")
                    btn_regenerate_last = gr.Button("🔄 Regenerate last instrument")
            with gr.Column():
                with gr.Group():
                    audio_output = gr.Video(show_share_button=True)
                    midi_file = gr.File()
                    with gr.Row():
                        qpm = gr.Slider(
                            minimum=60, maximum=140, step=10, value=120, label="Tempo"
                        )
                        btn_qpm = gr.Button("Change Tempo")
        with gr.Row():
            with gr.Column():
                plot_output = gr.Plot()
            with gr.Column():
                instruments_output = gr.Markdown("# List of generated instruments")
        with gr.Row():
            text_sequence = gr.Text()
            empty_sequence = gr.Text(visible=False)
        with gr.Row():
            num_tokens = gr.Text(visible=False)

        btn_from_scratch.click(
            fn=generate_song,
            inputs=[genre, temp, empty_sequence, qpm, prompt_text, duration_slider],
            outputs=[
                audio_output,
                midi_file,
                plot_output,
                instruments_output,
                text_sequence,
                num_tokens,
            ],
            api_name="generate_song_scratch",
        )
        btn_continue.click(
            fn=generate_song,
            inputs=[genre, temp, text_sequence, qpm, prompt_text, duration_slider],
            outputs=[
                audio_output,
                midi_file,
                plot_output,
                instruments_output,
                text_sequence,
                num_tokens,
            ],
            api_name="generate_song_continue",
        )
        btn_remove_last.click(
            fn=remove_last_instrument,
            inputs=[text_sequence, qpm],
            outputs=[
                audio_output,
                midi_file,
                plot_output,
                instruments_output,
                text_sequence,
                num_tokens,
            ],
            api_name="remove_last_instrument",
        )
        btn_regenerate_last.click(
            fn=regenerate_last_instrument,
            inputs=[text_sequence, qpm],
            outputs=[
                audio_output,
                midi_file,
                plot_output,
                instruments_output,
                text_sequence,
                num_tokens,
            ],
            api_name="regenerate_last_instrument",
        )
        btn_qpm.click(
            fn=change_tempo,
            inputs=[text_sequence, qpm],
            outputs=[
                audio_output,
                midi_file,
                plot_output,
                instruments_output,
                text_sequence,
                num_tokens,
            ],
            api_name="change_tempo",
        )

    demo.queue().launch(server_name="0.0.0.0", server_port=7860)


if __name__ == "__main__":
    run()