Update main.py
Browse files
main.py
CHANGED
@@ -9,6 +9,7 @@ from numpy import ndarray
|
|
9 |
from note_seq.protobuf.music_pb2 import NoteSequence
|
10 |
from note_seq.constants import STANDARD_PPQ
|
11 |
import logging
|
|
|
12 |
|
13 |
logging.basicConfig(level=logging.INFO)
|
14 |
|
@@ -145,6 +146,8 @@ GM_INSTRUMENTS = [
|
|
145 |
]
|
146 |
tokenizer = None
|
147 |
model = None
|
|
|
|
|
148 |
def get_model_and_tokenizer() -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
|
149 |
logging.info("get_model_and_tokenizer: Starting to load model and tokenizer...")
|
150 |
global model, tokenizer
|
@@ -311,9 +314,9 @@ def get_outputs_from_string(
|
|
311 |
instruments_str = "\n".join(f"- {instrument}" for instrument in instruments)
|
312 |
note_sequence = token_sequence_to_note_sequence(generated_sequence, qpm=qpm)
|
313 |
|
314 |
-
if not note_sequence.notes:
|
315 |
logging.warning("get_outputs_from_string: Note sequence is empty, skipping plot.")
|
316 |
-
fig = None
|
317 |
else:
|
318 |
fig = note_seq.plot_sequence(note_sequence, show_figure=False)
|
319 |
|
@@ -379,16 +382,18 @@ def generate_song(
|
|
379 |
text_sequence: str = "",
|
380 |
qpm: int = 120,
|
381 |
prompt: str = "",
|
382 |
-
duration: int =
|
383 |
) -> Tuple[ndarray, str, Figure, str, str, str]:
|
384 |
-
logging.info(f"generate_song: Starting song generation. Genre: {genre}, Temperature: {temp}, QPM: {qpm}, Duration: {duration}, Prompt: '{prompt}'")
|
385 |
if text_sequence == "":
|
386 |
seed_string = create_seed_string(genre, prompt)
|
387 |
else:
|
388 |
seed_string = text_sequence
|
389 |
|
|
|
|
|
390 |
generated_sequence = seed_string
|
391 |
-
for _ in range(
|
392 |
instrument_sequence = generate_new_instrument(seed=generated_sequence, temp=temp)
|
393 |
if instrument_sequence:
|
394 |
generated_sequence = instrument_sequence
|
@@ -410,7 +415,7 @@ def run():
|
|
410 |
with gr.Row():
|
411 |
with gr.Column():
|
412 |
prompt_text = gr.Textbox(lines=2, placeholder="Enter text prompt here...", label="Text Prompt (Optional)")
|
413 |
-
duration_slider = gr.Slider(minimum=1, maximum=
|
414 |
temp = gr.Slider(
|
415 |
minimum=0, maximum=1, step=0.05, value=0.85, label="Temperature"
|
416 |
)
|
|
|
9 |
from note_seq.protobuf.music_pb2 import NoteSequence
|
10 |
from note_seq.constants import STANDARD_PPQ
|
11 |
import logging
|
12 |
+
import math
|
13 |
|
14 |
logging.basicConfig(level=logging.INFO)
|
15 |
|
|
|
146 |
]
|
147 |
tokenizer = None
|
148 |
model = None
|
149 |
+
|
150 |
+
|
151 |
def get_model_and_tokenizer() -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
|
152 |
logging.info("get_model_and_tokenizer: Starting to load model and tokenizer...")
|
153 |
global model, tokenizer
|
|
|
314 |
instruments_str = "\n".join(f"- {instrument}" for instrument in instruments)
|
315 |
note_sequence = token_sequence_to_note_sequence(generated_sequence, qpm=qpm)
|
316 |
|
317 |
+
if not note_sequence.notes:
|
318 |
logging.warning("get_outputs_from_string: Note sequence is empty, skipping plot.")
|
319 |
+
fig = None
|
320 |
else:
|
321 |
fig = note_seq.plot_sequence(note_sequence, show_figure=False)
|
322 |
|
|
|
382 |
text_sequence: str = "",
|
383 |
qpm: int = 120,
|
384 |
prompt: str = "",
|
385 |
+
duration: int = 30
|
386 |
) -> Tuple[ndarray, str, Figure, str, str, str]:
|
387 |
+
logging.info(f"generate_song: Starting song generation. Genre: {genre}, Temperature: {temp}, QPM: {qpm}, Duration: {duration} seconds, Prompt: '{prompt}'")
|
388 |
if text_sequence == "":
|
389 |
seed_string = create_seed_string(genre, prompt)
|
390 |
else:
|
391 |
seed_string = text_sequence
|
392 |
|
393 |
+
num_tracks = max(1, int(math.ceil(duration / 17)))
|
394 |
+
|
395 |
generated_sequence = seed_string
|
396 |
+
for _ in range(num_tracks):
|
397 |
instrument_sequence = generate_new_instrument(seed=generated_sequence, temp=temp)
|
398 |
if instrument_sequence:
|
399 |
generated_sequence = instrument_sequence
|
|
|
415 |
with gr.Row():
|
416 |
with gr.Column():
|
417 |
prompt_text = gr.Textbox(lines=2, placeholder="Enter text prompt here...", label="Text Prompt (Optional)")
|
418 |
+
duration_slider = gr.Slider(minimum=1, maximum=1000, step=1, value=30, label="Duration (Seconds)")
|
419 |
temp = gr.Slider(
|
420 |
minimum=0, maximum=1, step=0.05, value=0.85, label="Temperature"
|
421 |
)
|