import os

# Select protobuf's pure-Python implementation *before* any protobuf-based
# package (note_seq) is imported; setting this later in the module has no
# effect, since protobuf reads the variable at import time.
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
import gradio as gr
import torch
from typing import Tuple, List, Optional
from transformers import AutoTokenizer, AutoModelForCausalLM
import note_seq
from matplotlib.figure import Figure
from numpy import ndarray
from note_seq.protobuf.music_pb2 import NoteSequence
from note_seq.constants import STANDARD_PPQ
import logging
import math
logging.basicConfig(level=logging.INFO)
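# Sample rate (Hz) used for fluidsynth audio rendering and the waveform video.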
SAMPLE_RATE = 44100
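# General MIDI program names, indexed by program number (0-127).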
GM_INSTRUMENTS = [
"Acoustic Grand Piano",
"Bright Acoustic Piano",
"Electric Grand Piano",
"Honky-tonk Piano",
"Electric Piano 1",
"Electric Piano 2",
"Harpsichord",
"Clavi",
"Celesta",
"Glockenspiel",
"Music Box",
"Vibraphone",
"Marimba",
"Xylophone",
"Tubular Bells",
"Dulcimer",
"Drawbar Organ",
"Percussive Organ",
"Rock Organ",
"Church Organ",
"Reed Organ",
"Accordion",
"Harmonica",
"Tango Accordion",
"Acoustic Guitar (nylon)",
"Acoustic Guitar (steel)",
"Electric Guitar (jazz)",
"Electric Guitar (clean)",
"Electric Guitar (muted)",
"Overdriven Guitar",
"Distortion Guitar",
"Guitar Harmonics",
"Acoustic Bass",
"Electric Bass (finger)",
"Electric Bass (pick)",
"Fretless Bass",
"Slap Bass 1",
"Slap Bass 2",
"Synth Bass 1",
"Synth Bass 2",
"Violin",
"Viola",
"Cello",
"Contrabass",
"Tremolo Strings",
"Pizzicato Strings",
"Orchestral Harp",
"Timpani",
"String Ensemble 1",
"String Ensemble 2",
"Synth Strings 1",
"Synth Strings 2",
"Choir Aahs",
"Voice Oohs",
"Synth Choir",
"Orchestra Hit",
"Trumpet",
"Trombone",
"Tuba",
"Muted Trumpet",
"French Horn",
"Brass Section",
"Synth Brass 1",
"Synth Brass 2",
"Soprano Sax",
"Alto Sax",
"Tenor Sax",
"Baritone Sax",
"Oboe",
"English Horn",
"Bassoon",
"Clarinet",
"Piccolo",
"Flute",
"Recorder",
"Pan Flute",
"Blown Bottle",
"Shakuhachi",
"Whistle",
"Ocarina",
"Lead 1 (square)",
"Lead 2 (sawtooth)",
"Lead 3 (calliope)",
"Lead 4 (chiff)",
"Lead 5 (charang)",
"Lead 6 (voice)",
"Lead 7 (fifths)",
"Lead 8 (bass + lead)",
"Pad 1 (new age)",
"Pad 2 (warm)",
"Pad 3 (polysynth)",
"Pad 4 (choir)",
"Pad 5 (bowed)",
"Pad 6 (metallic)",
"Pad 7 (halo)",
"Pad 8 (sweep)",
"FX 1 (rain)",
"FX 2 (soundtrack)",
"FX 3 (crystal)",
"FX 4 (atmosphere)",
"FX 5 (brightness)",
"FX 6 (goblins)",
"FX 7 (echoes)",
"FX 8 (sci-fi)",
"Sitar",
"Banjo",
"Shamisen",
"Koto",
"Kalimba",
"Bagpipe",
"Fiddle",
"Shanai",
"Tinkle Bell",
"Agogo",
"Steel Drums",
"Woodblock",
"Taiko Drum",
"Melodic Tom",
"Synth Drum",
"Reverse Cymbal",
"Guitar Fret Noise",
"Breath Noise",
"Seashore",
"Bird Tweet",
"Telephone Ring",
"Helicopter",
"Applause",
"Gunshot",
]
tokenizer = None
model = None
def get_model_and_tokenizer() -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
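    """Lazily load and cache the tokenizer and model in module-level globals.

    The first call loads the Hugging Face checkpoints and moves the model to
    GPU when available; later calls return the cached pair.
    """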
logging.info("get_model_and_tokenizer: Starting to load model and tokenizer...")
global model, tokenizer
if model is None or tokenizer is None:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logging.info(f"get_model_and_tokenizer: Using device: {device}")
tokenizer = AutoTokenizer.from_pretrained("juancopi81/lmd_8bars_tokenizer")
model = AutoModelForCausalLM.from_pretrained(
"juancopi81/lmd-8bars-2048-epochs40_v4"
)
model = model.to(device)
logging.info("get_model_and_tokenizer: Model and tokenizer loaded successfully.")
else:
logging.info("get_model_and_tokenizer: Model and tokenizer already loaded.")
return model, tokenizer
def token_sequence_to_note_sequence(
token_sequence: str,
qpm: float = 120.0,
use_program: bool = True,
use_drums: bool = True,
instrument_mapper: Optional[dict] = None,
only_piano: bool = False,
) -> NoteSequence:
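    """Convert a space-separated token string into a note_seq NoteSequence.

    Recognized tokens include PIECE_START/END, TRACK_START/END, INST=...,
    BAR_START/END, NOTE_ON=<pitch>, NOTE_OFF=<pitch>, TIME_DELTA=<sixteenths>,
    DENSITY=..., KEY=... and [PAD]. Timing assumes 4/4: a 16th note lasts
    0.25 * 60 / qpm seconds and a bar lasts 4.0 * 60 / qpm seconds. A NOTE_ON
    without a matching NOTE_OFF keeps the default quarter-note duration.
    """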
logging.info(f"token_sequence_to_note_sequence: Starting conversion. QPM: {qpm}, use_program: {use_program}, use_drums: {use_drums}, only_piano: {only_piano}")
if isinstance(token_sequence, str):
token_sequence = token_sequence.split()
note_sequence = empty_note_sequence(qpm)
note_length_16th = 0.25 * 60 / qpm
bar_length = 4.0 * 60 / qpm
    current_program = 1
    current_is_drum = False
    current_instrument = 0
    track_count = 0
    # Defensive defaults so a malformed sequence (e.g. a NOTE_ON before any
    # BAR_START) cannot raise NameError inside the loop below.
    current_bar_index = 0
    current_time = 0.0
    current_notes = {}
    for token in token_sequence:
if token == "PIECE_START":
pass
elif token == "PIECE_END":
break
        elif token == "TRACK_START":
            current_bar_index = 0
            track_count += 1
elif token == "TRACK_END":
pass
elif token == "KEYS_START":
pass
elif token == "KEYS_END":
pass
elif token.startswith("KEY="):
pass
elif token.startswith("INST"):
instrument = token.split("=")[-1]
if instrument != "DRUMS" and use_program:
if instrument_mapper is not None:
if instrument in instrument_mapper:
instrument = instrument_mapper[instrument]
current_program = int(instrument)
current_instrument = track_count
current_is_drum = False
if instrument == "DRUMS" and use_drums:
current_instrument = 0
current_program = 0
current_is_drum = True
elif token == "BAR_START":
current_time = current_bar_index * bar_length
current_notes = {}
        elif token == "BAR_END":
            current_bar_index += 1
elif token.startswith("NOTE_ON"):
pitch = int(token.split("=")[-1])
note = note_sequence.notes.add()
note.start_time = current_time
note.end_time = current_time + 4 * note_length_16th
note.pitch = pitch
note.instrument = current_instrument
note.program = current_program
note.velocity = 80
note.is_drum = current_is_drum
current_notes[pitch] = note
elif token.startswith("NOTE_OFF"):
pitch = int(token.split("=")[-1])
if pitch in current_notes:
note = current_notes[pitch]
note.end_time = current_time
elif token.startswith("TIME_DELTA"):
delta = float(token.split("=")[-1]) * note_length_16th
current_time += delta
elif token.startswith("DENSITY="):
pass
elif token == "[PAD]":
pass
else:
pass
instruments_drums = []
for note in note_sequence.notes:
pair = [note.program, note.is_drum]
if pair not in instruments_drums:
instruments_drums += [pair]
note.instrument = instruments_drums.index(pair)
if only_piano:
for note in note_sequence.notes:
if not note.is_drum:
note.instrument = 0
note.program = 0
logging.info("token_sequence_to_note_sequence: Conversion to note sequence complete.")
return note_sequence
def empty_note_sequence(qpm: float = 120.0, total_time: float = 0.0) -> NoteSequence:
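    """Return a NoteSequence with a single tempo event and standard PPQ."""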
note_sequence = NoteSequence()
note_sequence.tempos.add().qpm = qpm
note_sequence.ticks_per_quarter = STANDARD_PPQ
note_sequence.total_time = total_time
return note_sequence
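# Load the model and tokenizer eagerly at import time so the first request
# served by the app does not pay the loading cost.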
model, tokenizer = get_model_and_tokenizer()
def create_seed_string(genre: str = "OTHER", prompt: str = "") -> str:
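    """Build the initial token prompt from the genre and optional text prompt;
    the RANDOM genre yields a bare PIECE_START seed."""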
logging.info(f"create_seed_string: Creating seed string. Genre: {genre}, Prompt: '{prompt}'")
if prompt:
seed_string = f"PIECE_START PROMPT={prompt} GENRE={genre} TRACK_START"
elif genre == "RANDOM":
seed_string = "PIECE_START"
else:
seed_string = f"PIECE_START GENRE={genre} TRACK_START"
logging.info(f"create_seed_string: Seed string created: '{seed_string}'")
return seed_string
def get_instruments(text_sequence: str) -> List[str]:
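    """Map the INST=... tokens in a sequence to General MIDI instrument names
    (with "Drums" for INST=DRUMS)."""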
instruments = []
parts = text_sequence.split()
for part in parts:
if part.startswith("INST="):
if part[5:] == "DRUMS":
instruments.append("Drums")
else:
index = int(part[5:])
instruments.append(GM_INSTRUMENTS[index])
return instruments
def generate_new_instrument(seed: str, temp: float = 0.85, max_tokens: int = 512) -> str:
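    """Sample one new track from the model, stopping at TRACK_END.

    Returns the full decoded sequence (seed plus new track) when the newly
    generated portion contains at least one NOTE_ON token, otherwise "".
    """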
logging.info(f"generate_new_instrument: Starting instrument generation. Seed: '{seed}', Temperature: {temp}, Max Tokens: {max_tokens}")
seed_length = len(tokenizer.encode(seed))
input_ids = tokenizer.encode(seed, return_tensors="pt").to(model.device)
eos_token_id = tokenizer.encode("TRACK_END")[0]
generated_ids = model.generate(
input_ids,
max_new_tokens=max_tokens,
do_sample=True,
temperature=temp,
eos_token_id=eos_token_id,
)
generated_sequence = tokenizer.decode(generated_ids[0])
new_generated_sequence = tokenizer.decode(generated_ids[0][seed_length:])
logging.info(f"generate_new_instrument: Generated sequence: '{new_generated_sequence}'")
if "NOTE_ON" in new_generated_sequence:
logging.info("generate_new_instrument: New instrument generated successfully.")
return generated_sequence
else:
logging.warning("generate_new_instrument: No NOTE_ON token found in generated sequence after seed. Generation may be incomplete.")
return ""
def get_outputs_from_string(
generated_sequence: str, qpm: int = 120
) -> Tuple[ndarray, str, Figure, str, str]:
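    """Render a token sequence into all UI outputs: a waveform video (via
    gr.make_waveform), the written MIDI file's path, a piano-roll plot (None
    for an empty sequence), a Markdown instrument list, and the token count.
    """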
logging.info(f"get_outputs_from_string: Starting output generation. QPM: {qpm}")
instruments = get_instruments(generated_sequence)
instruments_str = "\n".join(f"- {instrument}" for instrument in instruments)
note_sequence = token_sequence_to_note_sequence(generated_sequence, qpm=qpm)
if not note_sequence.notes:
logging.warning("get_outputs_from_string: Note sequence is empty, skipping plot.")
fig = None
else:
fig = note_seq.plot_sequence(note_sequence, show_figure=False)
synth = note_seq.fluidsynth
array_of_floats = synth(note_sequence, sample_rate=SAMPLE_RATE)
int16_data = note_seq.audio_io.float_samples_to_int16(array_of_floats)
num_tokens = str(len(generated_sequence.split()))
audio = gr.make_waveform((SAMPLE_RATE, int16_data))
    note_seq.note_sequence_to_midi_file(note_sequence, "midi_output.mid")
    logging.info("get_outputs_from_string: Output generation complete.")
    return audio, "midi_output.mid", fig, instruments_str, num_tokens
def remove_last_instrument(
text_sequence: str, qpm: int = 120
) -> Tuple[ndarray, str, Figure, str, str, str]:
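    """Drop the final TRACK_START...TRACK_END block and re-render the rest.

    If the sequence held at most one track, a fresh single track is generated
    via generate_song instead of re-rendering an empty song.
    """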
logging.info(f"remove_last_instrument: Removing last instrument. QPM: {qpm}")
tracks = text_sequence.split("TRACK_START")
modified_tracks = tracks[:-1]
new_song = "TRACK_START".join(modified_tracks)
if len(tracks) == 2:
audio, midi_file, fig, instruments_str, new_song, num_tokens = generate_song(
text_sequence=new_song, qpm=qpm, duration=1
)
elif len(tracks) == 1:
audio, midi_file, fig, instruments_str, new_song, num_tokens = generate_song(
text_sequence="", qpm=qpm, duration=1
)
else:
audio, midi_file, fig, instruments_str, num_tokens = get_outputs_from_string(
new_song, qpm
)
logging.info("remove_last_instrument: Last instrument removed.")
return audio, midi_file, fig, instruments_str, new_song, num_tokens
def regenerate_last_instrument(
text_sequence: str, qpm: int = 120
) -> Tuple[ndarray, str, Figure, str, str, str]:
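    """Re-sample the notes of the last track while keeping its instrument.

    The sequence is truncated just past the final INST=... token and
    generation continues from there; without any INST token, a new song is
    generated from scratch.
    """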
logging.info(f"regenerate_last_instrument: Regenerating last instrument. QPM: {qpm}")
last_inst_index = text_sequence.rfind("INST=")
if last_inst_index == -1:
audio, midi_file, fig, instruments_str, new_song, num_tokens = generate_song(
text_sequence="", qpm=qpm, duration=1
)
else:
next_space_index = text_sequence.find(" ", last_inst_index)
new_seed = text_sequence[:next_space_index]
audio, midi_file, fig, instruments_str, new_song, num_tokens = generate_song(
text_sequence=new_seed, qpm=qpm, duration=1
)
logging.info("regenerate_last_instrument: Last instrument regenerated.")
return audio, midi_file, fig, instruments_str, new_song, num_tokens
def change_tempo(
text_sequence: str, qpm: int
) -> Tuple[ndarray, str, Figure, str, str, str]:
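    """Re-render the existing token sequence at a new tempo, without
    re-sampling any notes."""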
logging.info(f"change_tempo: Changing tempo to {qpm} QPM.")
audio, midi_file, fig, instruments_str, num_tokens = get_outputs_from_string(
text_sequence, qpm=qpm
)
logging.info(f"change_tempo: Tempo changed to {qpm} QPM.")
return audio, midi_file, fig, instruments_str, text_sequence, num_tokens
def generate_song(
genre: str = "OTHER",
temp: float = 0.85,
text_sequence: str = "",
qpm: int = 120,
prompt: str = "",
duration: int = 30
) -> Tuple[ndarray, str, Figure, str, str, str]:
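    """Generate a new song, or extend text_sequence, one track at a time.

    The requested duration is mapped to a track count (each track is 8 bars,
    roughly 16-17 seconds at the default 120 QPM); generation stops early if
    a track comes back empty.
    """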
logging.info(f"generate_song: Starting song generation. Genre: {genre}, Temperature: {temp}, QPM: {qpm}, Duration: {duration} seconds, Prompt: '{prompt}'")
if text_sequence == "":
seed_string = create_seed_string(genre, prompt)
else:
seed_string = text_sequence
    # One track per ~17 seconds of requested audio (see docstring above).
    num_tracks = max(1, math.ceil(duration / 17))
generated_sequence = seed_string
for _ in range(num_tracks):
instrument_sequence = generate_new_instrument(seed=generated_sequence, temp=temp)
if instrument_sequence:
generated_sequence = instrument_sequence
else:
logging.warning("generate_song: Instrument generation failed, stopping track generation early.")
break
audio, midi_file, fig, instruments_str, num_tokens = get_outputs_from_string(
generated_sequence, qpm
)
logging.info("generate_song: Song generation complete.")
return audio, midi_file, fig, instruments_str, generated_sequence, num_tokens
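# Example of programmatic use, bypassing the UI (a sketch; the argument
# values are illustrative only):
#   audio, midi_path, fig, instruments, token_seq, n_tokens = generate_song(
#       genre="POP", temp=0.85, qpm=120, duration=30
#   )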
genres = ["ROCK", "POP", "OTHER", "R&B/SOUL", "JAZZ", "ELECTRONIC", "RANDOM"]
demo = gr.Blocks()
def run():
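    """Build the Gradio Blocks UI, wire up callbacks, and launch the server
    on 0.0.0.0:7860."""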
with demo:
gr.DuplicateButton(value="Duplicate Space for private use")
with gr.Row():
with gr.Column():
prompt_text = gr.Textbox(lines=2, placeholder="Enter text prompt here...", label="Text Prompt (Optional)")
duration_slider = gr.Slider(minimum=1, maximum=1000, step=1, value=30, label="Duration (Seconds)")
temp = gr.Slider(
minimum=0, maximum=1, step=0.05, value=0.85, label="Temperature"
)
genre = gr.Dropdown(
choices=genres, value="POP", label="Select Genre"
)
with gr.Row():
btn_from_scratch = gr.Button("🧹 Start from scratch")
btn_continue = gr.Button("➡️ Continue Generation")
btn_remove_last = gr.Button("↩️ Remove last instrument")
btn_regenerate_last = gr.Button("🔄 Regenerate last instrument")
with gr.Column():
with gr.Group():
audio_output = gr.Video(show_share_button=True)
midi_file = gr.File()
with gr.Row():
qpm = gr.Slider(
minimum=60, maximum=140, step=10, value=120, label="Tempo"
)
btn_qpm = gr.Button("Change Tempo")
with gr.Row():
with gr.Column():
plot_output = gr.Plot()
with gr.Column():
instruments_output = gr.Markdown("# List of generated instruments")
with gr.Row():
text_sequence = gr.Text()
empty_sequence = gr.Text(visible=False)
with gr.Row():
num_tokens = gr.Text(visible=False)
btn_from_scratch.click(
fn=generate_song,
inputs=[genre, temp, empty_sequence, qpm, prompt_text, duration_slider],
outputs=[
audio_output,
midi_file,
plot_output,
instruments_output,
text_sequence,
num_tokens,
],
api_name="generate_song_scratch"
)
btn_continue.click(
fn=generate_song,
inputs=[genre, temp, text_sequence, qpm, prompt_text, duration_slider],
outputs=[
audio_output,
midi_file,
plot_output,
instruments_output,
text_sequence,
num_tokens,
],
api_name="generate_song_continue"
)
btn_remove_last.click(
fn=remove_last_instrument,
inputs=[text_sequence, qpm],
outputs=[
audio_output,
midi_file,
plot_output,
instruments_output,
text_sequence,
num_tokens,
],
api_name="remove_last_instrument"
)
btn_regenerate_last.click(
fn=regenerate_last_instrument,
inputs=[text_sequence, qpm],
outputs=[
audio_output,
midi_file,
plot_output,
instruments_output,
text_sequence,
num_tokens,
],
api_name="regenerate_last_instrument"
)
btn_qpm.click(
fn=change_tempo,
inputs=[text_sequence, qpm],
outputs=[
audio_output,
midi_file,
plot_output,
instruments_output,
text_sequence,
num_tokens,
],
api_name="change_tempo"
)
demo.queue().launch(server_name="0.0.0.0", server_port=7860)
if __name__ == "__main__":
run()