import os

# Force the pure-Python protobuf implementation. This must be set before
# note_seq (and its generated protobuf modules) are imported to take effect.
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"

import logging
import math
from typing import List, Optional, Tuple

import gradio as gr
import torch
from matplotlib.figure import Figure
from numpy import ndarray
from transformers import AutoModelForCausalLM, AutoTokenizer

import note_seq
from note_seq.constants import STANDARD_PPQ
from note_seq.protobuf.music_pb2 import NoteSequence

logging.basicConfig(level=logging.INFO)

SAMPLE_RATE = 44100

# General MIDI program names, indexed by program number (0-127).
GM_INSTRUMENTS = [
    "Acoustic Grand Piano", "Bright Acoustic Piano", "Electric Grand Piano", "Honky-tonk Piano",
    "Electric Piano 1", "Electric Piano 2", "Harpsichord", "Clavi",
    "Celesta", "Glockenspiel", "Music Box", "Vibraphone",
    "Marimba", "Xylophone", "Tubular Bells", "Dulcimer",
    "Drawbar Organ", "Percussive Organ", "Rock Organ", "Church Organ",
    "Reed Organ", "Accordion", "Harmonica", "Tango Accordion",
    "Acoustic Guitar (nylon)", "Acoustic Guitar (steel)", "Electric Guitar (jazz)", "Electric Guitar (clean)",
    "Electric Guitar (muted)", "Overdriven Guitar", "Distortion Guitar", "Guitar Harmonics",
    "Acoustic Bass", "Electric Bass (finger)", "Electric Bass (pick)", "Fretless Bass",
    "Slap Bass 1", "Slap Bass 2", "Synth Bass 1", "Synth Bass 2",
    "Violin", "Viola", "Cello", "Contrabass",
    "Tremolo Strings", "Pizzicato Strings", "Orchestral Harp", "Timpani",
    "String Ensemble 1", "String Ensemble 2", "Synth Strings 1", "Synth Strings 2",
    "Choir Aahs", "Voice Oohs", "Synth Choir", "Orchestra Hit",
    "Trumpet", "Trombone", "Tuba", "Muted Trumpet",
    "French Horn", "Brass Section", "Synth Brass 1", "Synth Brass 2",
    "Soprano Sax", "Alto Sax", "Tenor Sax", "Baritone Sax",
    "Oboe", "English Horn", "Bassoon", "Clarinet",
    "Piccolo", "Flute", "Recorder", "Pan Flute",
    "Blown Bottle", "Shakuhachi", "Whistle", "Ocarina",
    "Lead 1 (square)", "Lead 2 (sawtooth)", "Lead 3 (calliope)", "Lead 4 (chiff)",
    "Lead 5 (charang)", "Lead 6 (voice)", "Lead 7 (fifths)", "Lead 8 (bass + lead)",
    "Pad 1 (new age)", "Pad 2 (warm)", "Pad 3 (polysynth)", "Pad 4 (choir)",
    "Pad 5 (bowed)", "Pad 6 (metallic)", "Pad 7 (halo)", "Pad 8 (sweep)",
    "FX 1 (rain)", "FX 2 (soundtrack)", "FX 3 (crystal)", "FX 4 (atmosphere)",
    "FX 5 (brightness)", "FX 6 (goblins)", "FX 7 (echoes)", "FX 8 (sci-fi)",
    "Sitar", "Banjo", "Shamisen", "Koto",
    "Kalimba", "Bagpipe", "Fiddle", "Shanai",
    "Tinkle Bell", "Agogo", "Steel Drums", "Woodblock",
    "Taiko Drum", "Melodic Tom", "Synth Drum", "Reverse Cymbal",
    "Guitar Fret Noise", "Breath Noise", "Seashore", "Bird Tweet",
    "Telephone Ring", "Helicopter", "Applause", "Gunshot",
]

tokenizer = None
model = None


def get_model_and_tokenizer() -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
    """Lazily load the model and tokenizer, caching them in module globals."""
    global model, tokenizer
    logging.info("get_model_and_tokenizer: Starting to load model and tokenizer...")
    if model is None or tokenizer is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logging.info(f"get_model_and_tokenizer: Using device: {device}")
        tokenizer = AutoTokenizer.from_pretrained("juancopi81/lmd_8bars_tokenizer")
        model = AutoModelForCausalLM.from_pretrained(
            "juancopi81/lmd-8bars-2048-epochs40_v4"
        )
        model = model.to(device)
        logging.info("get_model_and_tokenizer: Model and tokenizer loaded successfully.")
    else:
        logging.info("get_model_and_tokenizer: Model and tokenizer already loaded.")
    return model, tokenizer
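# Usage sketch (illustrative, not executed at import time): because the loader
# memoizes into the module-level globals above, repeated calls return the same
# objects without reloading the checkpoint.
#
#   model, tokenizer = get_model_and_tokenizer()  # first call: loads from the Hub
#   model, tokenizer = get_model_and_tokenizer()  # later calls: returns the cache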
def token_sequence_to_note_sequence(
    token_sequence: str,
    qpm: float = 120.0,
    use_program: bool = True,
    use_drums: bool = True,
    instrument_mapper: Optional[dict] = None,
    only_piano: bool = False,
) -> NoteSequence:
    """Convert a generated token string into a note_seq NoteSequence."""
    logging.info(
        "token_sequence_to_note_sequence: Starting conversion. "
        f"QPM: {qpm}, use_program: {use_program}, use_drums: {use_drums}, "
        f"only_piano: {only_piano}"
    )
    if isinstance(token_sequence, str):
        token_sequence = token_sequence.split()

    note_sequence = empty_note_sequence(qpm)

    # Durations derived from the tempo: a 16th note and a 4/4 bar, in seconds.
    note_length_16th = 0.25 * 60 / qpm
    bar_length = 4.0 * 60 / qpm

    current_program = 1
    current_is_drum = False
    current_instrument = 0
    track_count = 0
    # Initialize cursor state up front so a malformed sequence (e.g. a NOTE_ON
    # before any TRACK_START/BAR_START) cannot raise a NameError.
    current_bar_index = 0
    current_time = 0.0
    current_notes = {}

    for token in token_sequence:
        if token == "PIECE_START":
            pass
        elif token == "PIECE_END":
            break
        elif token == "TRACK_START":
            current_bar_index = 0
            track_count += 1
        elif token == "TRACK_END":
            pass
        elif token == "KEYS_START":
            pass
        elif token == "KEYS_END":
            pass
        elif token.startswith("KEY="):
            pass
        elif token.startswith("INST"):
            instrument = token.split("=")[-1]
            if instrument != "DRUMS" and use_program:
                if instrument_mapper is not None and instrument in instrument_mapper:
                    instrument = instrument_mapper[instrument]
                current_program = int(instrument)
                current_instrument = track_count
                current_is_drum = False
            if instrument == "DRUMS" and use_drums:
                current_instrument = 0
                current_program = 0
                current_is_drum = True
        elif token == "BAR_START":
            current_time = current_bar_index * bar_length
            current_notes = {}
        elif token == "BAR_END":
            current_bar_index += 1
        elif token.startswith("NOTE_ON"):
            pitch = int(token.split("=")[-1])
            note = note_sequence.notes.add()
            note.start_time = current_time
            # Default to a quarter-note length; a matching NOTE_OFF shortens it.
            note.end_time = current_time + 4 * note_length_16th
            note.pitch = pitch
            note.instrument = current_instrument
            note.program = current_program
            note.velocity = 80
            note.is_drum = current_is_drum
            current_notes[pitch] = note
        elif token.startswith("NOTE_OFF"):
            pitch = int(token.split("=")[-1])
            if pitch in current_notes:
                current_notes[pitch].end_time = current_time
        elif token.startswith("TIME_DELTA"):
            current_time += float(token.split("=")[-1]) * note_length_16th
        elif token.startswith("DENSITY="):
            pass
        elif token == "[PAD]":
            pass
        else:
            pass

    # Reassign instrument indices so each (program, is_drum) pair gets its own track.
    instruments_drums = []
    for note in note_sequence.notes:
        pair = [note.program, note.is_drum]
        if pair not in instruments_drums:
            instruments_drums.append(pair)
        note.instrument = instruments_drums.index(pair)

    if only_piano:
        for note in note_sequence.notes:
            if not note.is_drum:
                note.instrument = 0
                note.program = 0

    logging.info("token_sequence_to_note_sequence: Conversion to note sequence complete.")
    return note_sequence


def empty_note_sequence(qpm: float = 120.0, total_time: float = 0.0) -> NoteSequence:
    """Create an empty NoteSequence with a single tempo and the standard PPQ."""
    note_sequence = NoteSequence()
    note_sequence.tempos.add().qpm = qpm
    note_sequence.ticks_per_quarter = STANDARD_PPQ
    note_sequence.total_time = total_time
    return note_sequence
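# Token-grammar sketch (a hypothetical hand-written sequence, matching the
# tokens the converter above handles): one track, one bar, a single middle C
# held for four 16th-note steps.
#
#   tokens = (
#       "PIECE_START TRACK_START INST=0 BAR_START "
#       "NOTE_ON=60 TIME_DELTA=4 NOTE_OFF=60 BAR_END TRACK_END"
#   )
#   ns = token_sequence_to_note_sequence(tokens, qpm=120.0)
#   # ns.notes[0] -> pitch 60, start 0.0 s, end 0.5 s (4 x 0.125 s at 120 QPM)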
# Load the model once at import time so the first request is not blocked.
model, tokenizer = get_model_and_tokenizer()


def create_seed_string(genre: str = "OTHER", prompt: str = "") -> str:
    """Build the seed token string from the selected genre and optional prompt."""
    logging.info(f"create_seed_string: Creating seed string. Genre: {genre}, Prompt: '{prompt}'")
    if prompt:
        seed_string = f"PIECE_START PROMPT={prompt} GENRE={genre} TRACK_START"
    elif genre == "RANDOM":
        seed_string = "PIECE_START"
    else:
        seed_string = f"PIECE_START GENRE={genre} TRACK_START"
    logging.info(f"create_seed_string: Seed string created: '{seed_string}'")
    return seed_string


def get_instruments(text_sequence: str) -> List[str]:
    """Map the INST= tokens in a generated sequence to General MIDI names."""
    instruments = []
    for part in text_sequence.split():
        if part.startswith("INST="):
            if part[5:] == "DRUMS":
                instruments.append("Drums")
            else:
                index = int(part[5:])
                instruments.append(GM_INSTRUMENTS[index])
    return instruments


def generate_new_instrument(seed: str, temp: float = 0.85, max_tokens: int = 512) -> str:
    """Generate one new track on top of the seed; return "" if generation failed."""
    logging.info(
        f"generate_new_instrument: Starting instrument generation. Seed: '{seed}', "
        f"Temperature: {temp}, Max Tokens: {max_tokens}"
    )
    seed_length = len(tokenizer.encode(seed))
    input_ids = tokenizer.encode(seed, return_tensors="pt").to(model.device)
    # Stop sampling once the new track is closed.
    eos_token_id = tokenizer.encode("TRACK_END")[0]

    generated_ids = model.generate(
        input_ids,
        max_new_tokens=max_tokens,
        do_sample=True,
        temperature=temp,
        eos_token_id=eos_token_id,
    )
    generated_sequence = tokenizer.decode(generated_ids[0])
    new_generated_sequence = tokenizer.decode(generated_ids[0][seed_length:])
    logging.info(f"generate_new_instrument: Generated sequence: '{new_generated_sequence}'")

    if "NOTE_ON" in new_generated_sequence:
        logging.info("generate_new_instrument: New instrument generated successfully.")
        return generated_sequence
    logging.warning(
        "generate_new_instrument: No NOTE_ON token found in the newly generated part. "
        "Generation may be incomplete."
    )
    return ""


def get_outputs_from_string(
    generated_sequence: str, qpm: int = 120
) -> Tuple[ndarray, str, Figure, str, str]:
    """Render a token sequence to waveform video, MIDI file, plot, and metadata."""
    logging.info(f"get_outputs_from_string: Starting output generation. QPM: {qpm}")
    instruments = get_instruments(generated_sequence)
    instruments_str = "\n".join(f"- {instrument}" for instrument in instruments)
    note_sequence = token_sequence_to_note_sequence(generated_sequence, qpm=qpm)

    if not note_sequence.notes:
        logging.warning("get_outputs_from_string: Note sequence is empty, skipping plot.")
        fig = None
    else:
        fig = note_seq.plot_sequence(note_sequence, show_figure=False)

    # Synthesize with FluidSynth, then convert to 16-bit PCM for the waveform video.
    array_of_floats = note_seq.fluidsynth(note_sequence, sample_rate=SAMPLE_RATE)
    int16_data = note_seq.audio_io.float_samples_to_int16(array_of_floats)
    num_tokens = str(len(generated_sequence.split()))
    audio = gr.make_waveform((SAMPLE_RATE, int16_data))
    note_seq.note_sequence_to_midi_file(note_sequence, "midi_output.mid")
    logging.info("get_outputs_from_string: Output generation complete.")
    return audio, "midi_output.mid", fig, instruments_str, num_tokens
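# Core-loop sketch (illustrative): each successful call to
# generate_new_instrument appends one "TRACK_START ... TRACK_END" block to the
# running token string, so tracks are stacked by feeding the output back in as
# the next seed.
#
#   seq = create_seed_string(genre="POP")
#   for _ in range(2):
#       nxt = generate_new_instrument(seed=seq, temp=0.85)
#       if not nxt:
#           break
#       seq = nxt
#   outputs = get_outputs_from_string(seq, qpm=120)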
def remove_last_instrument(
    text_sequence: str, qpm: int = 120
) -> Tuple[ndarray, str, Figure, str, str, str]:
    logging.info(f"remove_last_instrument: Removing last instrument. QPM: {qpm}")
    # Drop everything after the final TRACK_START marker.
    tracks = text_sequence.split("TRACK_START")
    modified_tracks = tracks[:-1]
    new_song = "TRACK_START".join(modified_tracks)

    if len(tracks) == 2:
        # Only one track existed: regenerate a single track from the remaining header.
        audio, midi_file, fig, instruments_str, new_song, num_tokens = generate_song(
            text_sequence=new_song, qpm=qpm, duration=1
        )
    elif len(tracks) == 1:
        # Nothing to remove: generate a fresh one-track song.
        audio, midi_file, fig, instruments_str, new_song, num_tokens = generate_song(
            text_sequence="", qpm=qpm, duration=1
        )
    else:
        audio, midi_file, fig, instruments_str, num_tokens = get_outputs_from_string(
            new_song, qpm
        )
    logging.info("remove_last_instrument: Last instrument removed.")
    return audio, midi_file, fig, instruments_str, new_song, num_tokens


def regenerate_last_instrument(
    text_sequence: str, qpm: int = 120
) -> Tuple[ndarray, str, Figure, str, str, str]:
    logging.info(f"regenerate_last_instrument: Regenerating last instrument. QPM: {qpm}")
    last_inst_index = text_sequence.rfind("INST=")

    if last_inst_index == -1:
        # No instrument token found: start a new song instead.
        audio, midi_file, fig, instruments_str, new_song, num_tokens = generate_song(
            text_sequence="", qpm=qpm, duration=1
        )
    else:
        # Keep the sequence up to the last INST= token and resample from there.
        next_space_index = text_sequence.find(" ", last_inst_index)
        new_seed = text_sequence[:next_space_index]
        audio, midi_file, fig, instruments_str, new_song, num_tokens = generate_song(
            text_sequence=new_seed, qpm=qpm, duration=1
        )
    logging.info("regenerate_last_instrument: Last instrument regenerated.")
    return audio, midi_file, fig, instruments_str, new_song, num_tokens


def change_tempo(
    text_sequence: str, qpm: int
) -> Tuple[ndarray, str, Figure, str, str, str]:
    logging.info(f"change_tempo: Changing tempo to {qpm} QPM.")
    audio, midi_file, fig, instruments_str, num_tokens = get_outputs_from_string(
        text_sequence, qpm=qpm
    )
    logging.info(f"change_tempo: Tempo changed to {qpm} QPM.")
    return audio, midi_file, fig, instruments_str, text_sequence, num_tokens
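# Behavior sketch for the editing helpers (hypothetical token string; the
# "..." stands in for generated note tokens):
#
#   song = ("PIECE_START GENRE=POP TRACK_START INST=0 ... TRACK_END "
#           "TRACK_START INST=30 ... TRACK_END")
#   remove_last_instrument(song, qpm=120)      # re-renders without the INST=30 track
#   regenerate_last_instrument(song, qpm=120)  # keeps "... TRACK_START INST=30"
#                                              # as the seed and resamples the rest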
def generate_song(
    genre: str = "OTHER",
    temp: float = 0.85,
    text_sequence: str = "",
    qpm: int = 120,
    prompt: str = "",
    duration: int = 30,
) -> Tuple[ndarray, str, Figure, str, str, str]:
    logging.info(
        "generate_song: Starting song generation. "
        f"Genre: {genre}, Temperature: {temp}, QPM: {qpm}, "
        f"Duration: {duration} seconds, Prompt: '{prompt}'"
    )
    if text_sequence == "":
        seed_string = create_seed_string(genre, prompt)
    else:
        seed_string = text_sequence

    # Each track covers roughly 17 seconds of music (8 bars at 120 QPM is about
    # 16 s), so the requested duration determines how many tracks to generate.
    num_tracks = max(1, int(math.ceil(duration / 17)))

    generated_sequence = seed_string
    for _ in range(num_tracks):
        instrument_sequence = generate_new_instrument(seed=generated_sequence, temp=temp)
        if instrument_sequence:
            generated_sequence = instrument_sequence
        else:
            logging.warning(
                "generate_song: Instrument generation failed, stopping track generation early."
            )
            break

    audio, midi_file, fig, instruments_str, num_tokens = get_outputs_from_string(
        generated_sequence, qpm
    )
    logging.info("generate_song: Song generation complete.")
    return audio, midi_file, fig, instruments_str, generated_sequence, num_tokens


genres = ["ROCK", "POP", "OTHER", "R&B/SOUL", "JAZZ", "ELECTRONIC", "RANDOM"]

demo = gr.Blocks()


def run():
    with demo:
        gr.DuplicateButton(value="Duplicate Space for private use")
        with gr.Row():
            with gr.Column():
                prompt_text = gr.Textbox(
                    lines=2,
                    placeholder="Enter text prompt here...",
                    label="Text Prompt (Optional)",
                )
                duration_slider = gr.Slider(
                    minimum=1, maximum=1000, step=1, value=30, label="Duration (Seconds)"
                )
                temp = gr.Slider(
                    minimum=0, maximum=1, step=0.05, value=0.85, label="Temperature"
                )
                genre = gr.Dropdown(choices=genres, value="POP", label="Select Genre")
                with gr.Row():
                    btn_from_scratch = gr.Button("🧹 Start from scratch")
                    btn_continue = gr.Button("➡️ Continue Generation")
                    btn_remove_last = gr.Button("↩️ Remove last instrument")
                    btn_regenerate_last = gr.Button("🔄 Regenerate last instrument")
            with gr.Column():
                with gr.Group():
                    audio_output = gr.Video(show_share_button=True)
                    midi_file = gr.File()
                    with gr.Row():
                        qpm = gr.Slider(
                            minimum=60, maximum=140, step=10, value=120, label="Tempo"
                        )
                        btn_qpm = gr.Button("Change Tempo")
        with gr.Row():
            with gr.Column():
                plot_output = gr.Plot()
            with gr.Column():
                instruments_output = gr.Markdown("# List of generated instruments")
        with gr.Row():
            text_sequence = gr.Text()
            empty_sequence = gr.Text(visible=False)
        with gr.Row():
            num_tokens = gr.Text(visible=False)

        btn_from_scratch.click(
            fn=generate_song,
            inputs=[genre, temp, empty_sequence, qpm, prompt_text, duration_slider],
            outputs=[
                audio_output,
                midi_file,
                plot_output,
                instruments_output,
                text_sequence,
                num_tokens,
            ],
            api_name="generate_song_scratch",
        )
        btn_continue.click(
            fn=generate_song,
            inputs=[genre, temp, text_sequence, qpm, prompt_text, duration_slider],
            outputs=[
                audio_output,
                midi_file,
                plot_output,
                instruments_output,
                text_sequence,
                num_tokens,
            ],
            api_name="generate_song_continue",
        )
        btn_remove_last.click(
            fn=remove_last_instrument,
            inputs=[text_sequence, qpm],
            outputs=[
                audio_output,
                midi_file,
                plot_output,
                instruments_output,
                text_sequence,
                num_tokens,
            ],
            api_name="remove_last_instrument",
        )
        btn_regenerate_last.click(
            fn=regenerate_last_instrument,
            inputs=[text_sequence, qpm],
            outputs=[
                audio_output,
                midi_file,
                plot_output,
                instruments_output,
                text_sequence,
                num_tokens,
            ],
            api_name="regenerate_last_instrument",
        )
        btn_qpm.click(
            fn=change_tempo,
            inputs=[text_sequence, qpm],
            outputs=[
                audio_output,
                midi_file,
                plot_output,
                instruments_output,
                text_sequence,
                num_tokens,
            ],
            api_name="change_tempo",
        )

    demo.queue().launch(server_name="0.0.0.0", server_port=7860)


if __name__ == "__main__":
    run()
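# Remote-call sketch (assumes the app is reachable at http://localhost:7860 and
# that the gradio_client package is installed; the endpoint names come from the
# api_name arguments registered above):
#
#   from gradio_client import Client
#   client = Client("http://localhost:7860")
#   result = client.predict(
#       "POP", 0.85, "", 120, "", 30, api_name="/generate_song_scratch"
#   )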