Update main.py
Browse files
main.py
CHANGED
@@ -263,7 +263,7 @@ model, tokenizer = get_model_and_tokenizer()
|
|
263 |
def create_seed_string(genre: str = "OTHER", prompt: str = "") -> str:
|
264 |
logging.info(f"create_seed_string: Creating seed string. Genre: {genre}, Prompt: '{prompt}'")
|
265 |
if prompt:
|
266 |
-
seed_string = f"PIECE_START PROMPT={prompt} GENRE={genre} TRACK_START"
|
267 |
elif genre == "RANDOM":
|
268 |
seed_string = "PIECE_START"
|
269 |
else:
|
@@ -281,7 +281,7 @@ def get_instruments(text_sequence: str) -> List[str]:
|
|
281 |
index = int(part[5:])
|
282 |
instruments.append(GM_INSTRUMENTS[index])
|
283 |
return instruments
|
284 |
-
def generate_new_instrument(seed: str, temp: float = 0.75, max_tokens=204) -> str:
|
285 |
logging.info(f"generate_new_instrument: Starting instrument generation. Seed: '{seed}', Temperature: {temp}, Max Tokens: {max_tokens}")
|
286 |
seed_length = len(tokenizer.encode(seed))
|
287 |
input_ids = tokenizer.encode(seed, return_tensors="pt").to(model.device)
|
@@ -295,7 +295,7 @@ def generate_new_instrument(seed: str, temp: float = 0.75, max_tokens=204) -> st
|
|
295 |
)
|
296 |
generated_sequence = tokenizer.decode(generated_ids[0])
|
297 |
new_generated_sequence = tokenizer.decode(generated_ids[0][seed_length:])
|
298 |
-
logging.info(f"generate_new_instrument: Generated sequence: '{new_generated_sequence}'")
|
299 |
if "NOTE_ON" in new_generated_sequence:
|
300 |
logging.info("generate_new_instrument: New instrument generated successfully.")
|
301 |
return generated_sequence
|
@@ -310,15 +310,22 @@ def get_outputs_from_string(
|
|
310 |
instruments = get_instruments(generated_sequence)
|
311 |
instruments_str = "\n".join(f"- {instrument}" for instrument in instruments)
|
312 |
note_sequence = token_sequence_to_note_sequence(generated_sequence, qpm=qpm)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
313 |
synth = note_seq.fluidsynth
|
314 |
array_of_floats = synth(note_sequence, sample_rate=SAMPLE_RATE)
|
315 |
int16_data = note_seq.audio_io.float_samples_to_int16(array_of_floats)
|
316 |
-
fig = note_seq.plot_sequence(note_sequence, show_figure=False)
|
317 |
num_tokens = str(len(generated_sequence.split()))
|
318 |
audio = gr.make_waveform((SAMPLE_RATE, int16_data))
|
319 |
note_seq.note_sequence_to_midi_file(note_sequence, "midi_ouput.mid")
|
320 |
logging.info("get_outputs_from_string: Output generation complete.")
|
321 |
return audio, "midi_ouput.mid", fig, instruments_str, num_tokens
|
|
|
322 |
def remove_last_instrument(
|
323 |
text_sequence: str, qpm: int = 120
|
324 |
) -> Tuple[ndarray, str, Figure, str, str, str]:
|
@@ -368,7 +375,7 @@ def change_tempo(
|
|
368 |
return audio, midi_file, fig, instruments_str, text_sequence, num_tokens
|
369 |
def generate_song(
|
370 |
genre: str = "OTHER",
|
371 |
-
temp: float = 0.85,
|
372 |
text_sequence: str = "",
|
373 |
qpm: int = 120,
|
374 |
prompt: str = "",
|
|
|
263 |
def create_seed_string(genre: str = "OTHER", prompt: str = "") -> str:
|
264 |
logging.info(f"create_seed_string: Creating seed string. Genre: {genre}, Prompt: '{prompt}'")
|
265 |
if prompt:
|
266 |
+
seed_string = f"PIECE_START PROMPT={prompt} GENRE={genre} TRACK_START"
|
267 |
elif genre == "RANDOM":
|
268 |
seed_string = "PIECE_START"
|
269 |
else:
|
|
|
281 |
index = int(part[5:])
|
282 |
instruments.append(GM_INSTRUMENTS[index])
|
283 |
return instruments
|
284 |
+
def generate_new_instrument(seed: str, temp: float = 0.85, max_tokens=512) -> str:
|
285 |
logging.info(f"generate_new_instrument: Starting instrument generation. Seed: '{seed}', Temperature: {temp}, Max Tokens: {max_tokens}")
|
286 |
seed_length = len(tokenizer.encode(seed))
|
287 |
input_ids = tokenizer.encode(seed, return_tensors="pt").to(model.device)
|
|
|
295 |
)
|
296 |
generated_sequence = tokenizer.decode(generated_ids[0])
|
297 |
new_generated_sequence = tokenizer.decode(generated_ids[0][seed_length:])
|
298 |
+
logging.info(f"generate_new_instrument: Generated sequence: '{new_generated_sequence}'")
|
299 |
if "NOTE_ON" in new_generated_sequence:
|
300 |
logging.info("generate_new_instrument: New instrument generated successfully.")
|
301 |
return generated_sequence
|
|
|
310 |
instruments = get_instruments(generated_sequence)
|
311 |
instruments_str = "\n".join(f"- {instrument}" for instrument in instruments)
|
312 |
note_sequence = token_sequence_to_note_sequence(generated_sequence, qpm=qpm)
|
313 |
+
|
314 |
+
if not note_sequence.notes: # Check if note_sequence is empty
|
315 |
+
logging.warning("get_outputs_from_string: Note sequence is empty, skipping plot.")
|
316 |
+
fig = None # Handle case where fig is None
|
317 |
+
else:
|
318 |
+
fig = note_seq.plot_sequence(note_sequence, show_figure=False)
|
319 |
+
|
320 |
synth = note_seq.fluidsynth
|
321 |
array_of_floats = synth(note_sequence, sample_rate=SAMPLE_RATE)
|
322 |
int16_data = note_seq.audio_io.float_samples_to_int16(array_of_floats)
|
|
|
323 |
num_tokens = str(len(generated_sequence.split()))
|
324 |
audio = gr.make_waveform((SAMPLE_RATE, int16_data))
|
325 |
note_seq.note_sequence_to_midi_file(note_sequence, "midi_ouput.mid")
|
326 |
logging.info("get_outputs_from_string: Output generation complete.")
|
327 |
return audio, "midi_ouput.mid", fig, instruments_str, num_tokens
|
328 |
+
|
329 |
def remove_last_instrument(
|
330 |
text_sequence: str, qpm: int = 120
|
331 |
) -> Tuple[ndarray, str, Figure, str, str, str]:
|
|
|
375 |
return audio, midi_file, fig, instruments_str, text_sequence, num_tokens
|
376 |
def generate_song(
|
377 |
genre: str = "OTHER",
|
378 |
+
temp: float = 0.85,
|
379 |
text_sequence: str = "",
|
380 |
qpm: int = 120,
|
381 |
prompt: str = "",
|