Update main.py
Browse files
main.py
CHANGED
|
@@ -9,6 +9,7 @@ from numpy import ndarray
|
|
| 9 |
from note_seq.protobuf.music_pb2 import NoteSequence
|
| 10 |
from note_seq.constants import STANDARD_PPQ
|
| 11 |
import logging
|
|
|
|
| 12 |
|
| 13 |
logging.basicConfig(level=logging.INFO)
|
| 14 |
|
|
@@ -145,6 +146,8 @@ GM_INSTRUMENTS = [
|
|
| 145 |
]
|
| 146 |
tokenizer = None
|
| 147 |
model = None
|
|
|
|
|
|
|
| 148 |
def get_model_and_tokenizer() -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
|
| 149 |
logging.info("get_model_and_tokenizer: Starting to load model and tokenizer...")
|
| 150 |
global model, tokenizer
|
|
@@ -311,9 +314,9 @@ def get_outputs_from_string(
|
|
| 311 |
instruments_str = "\n".join(f"- {instrument}" for instrument in instruments)
|
| 312 |
note_sequence = token_sequence_to_note_sequence(generated_sequence, qpm=qpm)
|
| 313 |
|
| 314 |
-
if not note_sequence.notes:
|
| 315 |
logging.warning("get_outputs_from_string: Note sequence is empty, skipping plot.")
|
| 316 |
-
fig = None
|
| 317 |
else:
|
| 318 |
fig = note_seq.plot_sequence(note_sequence, show_figure=False)
|
| 319 |
|
|
@@ -379,16 +382,18 @@ def generate_song(
|
|
| 379 |
text_sequence: str = "",
|
| 380 |
qpm: int = 120,
|
| 381 |
prompt: str = "",
|
| 382 |
-
duration: int =
|
| 383 |
) -> Tuple[ndarray, str, Figure, str, str, str]:
|
| 384 |
-
logging.info(f"generate_song: Starting song generation. Genre: {genre}, Temperature: {temp}, QPM: {qpm}, Duration: {duration}, Prompt: '{prompt}'")
|
| 385 |
if text_sequence == "":
|
| 386 |
seed_string = create_seed_string(genre, prompt)
|
| 387 |
else:
|
| 388 |
seed_string = text_sequence
|
| 389 |
|
|
|
|
|
|
|
| 390 |
generated_sequence = seed_string
|
| 391 |
-
for _ in range(
|
| 392 |
instrument_sequence = generate_new_instrument(seed=generated_sequence, temp=temp)
|
| 393 |
if instrument_sequence:
|
| 394 |
generated_sequence = instrument_sequence
|
|
@@ -410,7 +415,7 @@ def run():
|
|
| 410 |
with gr.Row():
|
| 411 |
with gr.Column():
|
| 412 |
prompt_text = gr.Textbox(lines=2, placeholder="Enter text prompt here...", label="Text Prompt (Optional)")
|
| 413 |
-
duration_slider = gr.Slider(minimum=1, maximum=
|
| 414 |
temp = gr.Slider(
|
| 415 |
minimum=0, maximum=1, step=0.05, value=0.85, label="Temperature"
|
| 416 |
)
|
|
|
|
| 9 |
from note_seq.protobuf.music_pb2 import NoteSequence
|
| 10 |
from note_seq.constants import STANDARD_PPQ
|
| 11 |
import logging
|
| 12 |
+
import math
|
| 13 |
|
| 14 |
logging.basicConfig(level=logging.INFO)
|
| 15 |
|
|
|
|
| 146 |
]
|
| 147 |
tokenizer = None
|
| 148 |
model = None
|
| 149 |
+
|
| 150 |
+
|
| 151 |
def get_model_and_tokenizer() -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
|
| 152 |
logging.info("get_model_and_tokenizer: Starting to load model and tokenizer...")
|
| 153 |
global model, tokenizer
|
|
|
|
| 314 |
instruments_str = "\n".join(f"- {instrument}" for instrument in instruments)
|
| 315 |
note_sequence = token_sequence_to_note_sequence(generated_sequence, qpm=qpm)
|
| 316 |
|
| 317 |
+
if not note_sequence.notes:
|
| 318 |
logging.warning("get_outputs_from_string: Note sequence is empty, skipping plot.")
|
| 319 |
+
fig = None
|
| 320 |
else:
|
| 321 |
fig = note_seq.plot_sequence(note_sequence, show_figure=False)
|
| 322 |
|
|
|
|
| 382 |
text_sequence: str = "",
|
| 383 |
qpm: int = 120,
|
| 384 |
prompt: str = "",
|
| 385 |
+
duration: int = 30
|
| 386 |
) -> Tuple[ndarray, str, Figure, str, str, str]:
|
| 387 |
+
logging.info(f"generate_song: Starting song generation. Genre: {genre}, Temperature: {temp}, QPM: {qpm}, Duration: {duration} seconds, Prompt: '{prompt}'")
|
| 388 |
if text_sequence == "":
|
| 389 |
seed_string = create_seed_string(genre, prompt)
|
| 390 |
else:
|
| 391 |
seed_string = text_sequence
|
| 392 |
|
| 393 |
+
num_tracks = max(1, int(math.ceil(duration / 17)))
|
| 394 |
+
|
| 395 |
generated_sequence = seed_string
|
| 396 |
+
for _ in range(num_tracks):
|
| 397 |
instrument_sequence = generate_new_instrument(seed=generated_sequence, temp=temp)
|
| 398 |
if instrument_sequence:
|
| 399 |
generated_sequence = instrument_sequence
|
|
|
|
| 415 |
with gr.Row():
|
| 416 |
with gr.Column():
|
| 417 |
prompt_text = gr.Textbox(lines=2, placeholder="Enter text prompt here...", label="Text Prompt (Optional)")
|
| 418 |
+
duration_slider = gr.Slider(minimum=1, maximum=1000, step=1, value=30, label="Duration (Seconds)")
|
| 419 |
temp = gr.Slider(
|
| 420 |
minimum=0, maximum=1, step=0.05, value=0.85, label="Temperature"
|
| 421 |
)
|