Commit b7070f2 · Parent: db8b2d5
Refactor inference functions to accept DEVICE and MODEL parameters for TC5, TC6, and TC7; update model loading to use GPU if available.
app.py (CHANGED)
@@ -11,34 +11,43 @@ from tc7 import infer as tc7infer
 from gradio_client import Client, handle_file
 import tempfile
 
-
+GPU_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 # Load model once
 tc5 = TaikoConformer5.from_pretrained("JacobLinCool/taiko-conformer-5")
-tc5.to(
+tc5.to(GPU_DEVICE)
 tc5.eval()
+tc5_cpu = TaikoConformer5.from_pretrained("JacobLinCool/taiko-conformer-5")
+tc5_cpu.to("cpu")
+tc5_cpu.eval()
 
 # Load TC6 model
 tc6 = TaikoConformer6.from_pretrained("JacobLinCool/taiko-conformer-6")
-tc6.to(
+tc6.to(GPU_DEVICE)
 tc6.eval()
+tc6_cpu = TaikoConformer6.from_pretrained("JacobLinCool/taiko-conformer-6")
+tc6_cpu.to("cpu")
+tc6_cpu.eval()
 
 # Load TC7 model
 tc7 = TaikoConformer7.from_pretrained("JacobLinCool/taiko-conformer-7")
-tc7.to(
+tc7.to(GPU_DEVICE)
 tc7.eval()
+tc7_cpu = TaikoConformer7.from_pretrained("JacobLinCool/taiko-conformer-7")
+tc7_cpu.to("cpu")
+tc7_cpu.eval()
 
 synthesizer = Client("ryanlinjui/taiko-music-generator")
 
 
-def infer_tc5(audio, nps, bpm, offset):
+def infer_tc5(audio, nps, bpm, offset, DEVICE, MODEL):
     audio_path = audio
     filename = audio_path.split("/")[-1]
     # Preprocess
     mel_input, nps_input = tc5infer.preprocess_audio(audio_path, nps)
     # Inference
     don_energy, ka_energy, drumroll_energy = tc5infer.run_inference(
-
+        MODEL, mel_input, nps_input, DEVICE
     )
     output_frame_hop_sec = HOP_LENGTH / SAMPLE_RATE
     onsets = tc5infer.decode_onsets(
@@ -91,7 +100,7 @@ def infer_tc5(audio, nps, bpm, offset):
     return oni_audio, plot, tja_content
 
 
-def infer_tc6(audio, nps, bpm, offset, difficulty, level):
+def infer_tc6(audio, nps, bpm, offset, difficulty, level, DEVICE, MODEL):
     audio_path = audio
     filename = audio_path.split("/")[-1]
     # Preprocess
@@ -101,7 +110,7 @@ def infer_tc6(audio, nps, bpm, offset, difficulty, level):
     level_input = torch.tensor(level, dtype=torch.float32).to(DEVICE)
     # Inference
     don_energy, ka_energy, drumroll_energy = tc6infer.run_inference(
-
+        MODEL, mel_input, nps_input, difficulty_input, level_input, DEVICE
     )
     output_frame_hop_sec = HOP_LENGTH / SAMPLE_RATE
     onsets = tc6infer.decode_onsets(
@@ -154,7 +163,7 @@ def infer_tc6(audio, nps, bpm, offset, difficulty, level):
     return oni_audio, plot, tja_content
 
 
-def infer_tc7(audio, nps, bpm, offset, difficulty, level):
+def infer_tc7(audio, nps, bpm, offset, difficulty, level, DEVICE, MODEL):
     audio_path = audio
     filename = audio_path.split("/")[-1]
     # Preprocess
@@ -164,7 +173,7 @@ def infer_tc7(audio, nps, bpm, offset, difficulty, level):
     level_input = torch.tensor(level, dtype=torch.float32).to(DEVICE)
     # Inference
    don_energy, ka_energy, drumroll_energy = tc7infer.run_inference(
-
+        MODEL, mel_input, nps_input, difficulty_input, level_input, DEVICE
     )
     output_frame_hop_sec = HOP_LENGTH / SAMPLE_RATE
     onsets = tc7infer.decode_onsets(
@@ -220,20 +229,21 @@
 @spaces.GPU
 def run_inference_gpu(audio, model_choice, nps, bpm, offset, difficulty, level):
     if model_choice == "TC5":
-        return infer_tc5(audio, nps, bpm, offset)
+        return infer_tc5(audio, nps, bpm, offset, GPU_DEVICE, tc5)
     elif model_choice == "TC6":
-        return infer_tc6(audio, nps, bpm, offset, difficulty, level)
+        return infer_tc6(audio, nps, bpm, offset, difficulty, level, GPU_DEVICE, tc6)
     else: # TC7
-        return infer_tc7(audio, nps, bpm, offset, difficulty, level)
+        return infer_tc7(audio, nps, bpm, offset, difficulty, level, GPU_DEVICE, tc7)
 
 
 def run_inference_cpu(audio, model_choice, nps, bpm, offset, difficulty, level):
+    DEVICE = torch.device("cpu")
     if model_choice == "TC5":
-        return infer_tc5(audio, nps, bpm, offset)
+        return infer_tc5(audio, nps, bpm, offset, DEVICE, tc5_cpu)
     elif model_choice == "TC6":
-        return infer_tc6(audio, nps, bpm, offset, difficulty, level)
+        return infer_tc6(audio, nps, bpm, offset, difficulty, level, DEVICE, tc6_cpu)
     else: # TC7
-        return infer_tc7(audio, nps, bpm, offset, difficulty, level)
+        return infer_tc7(audio, nps, bpm, offset, difficulty, level, DEVICE, tc7_cpu)
 
 
 def run_inference(with_gpu, audio, model_choice, nps, bpm, offset, difficulty, level):
@@ -330,7 +340,16 @@ with gr.Blocks() as demo:
 
     run_btn.click(
         run_inference,
-        inputs=[
+        inputs=[
+            with_gpu,
+            audio_input,
+            model_choice,
+            nps,
+            bpm,
+            offset,
+            difficulty,
+            level,
+        ],
         outputs=[audio_output, plot_output, tja_output],
     )
 
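Taken together, the change converges on one pattern: load a copy of each model per device once at startup, and pass the device and the model into every inference helper instead of reading module-level globals. A minimal sketch of that pattern, assuming plain PyTorch; DummyModel, model_gpu, model_cpu, and this run_inference signature are illustrative stand-ins, not the real TaikoConformer code:

import torch

class DummyModel(torch.nn.Module):
    # Stand-in for a TaikoConformer checkpoint.
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(80, 3)

    def forward(self, mel):
        return self.proj(mel)

GPU_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# One resident copy per device, created once at import time,
# mirroring the tc5/tc5_cpu, tc6/tc6_cpu, tc7/tc7_cpu pairs above.
model_gpu = DummyModel().to(GPU_DEVICE).eval()
model_cpu = DummyModel().to("cpu").eval()

def run_inference(model, mel, device):
    # Inputs follow the chosen model copy; no global device state.
    mel = mel.to(device)
    with torch.no_grad():
        return model(mel)

out = run_inference(model_cpu, torch.randn(1, 80), torch.device("cpu"))

The cost of the dual copies is that every checkpoint is held in host memory twice; the benefit is that neither code path has to move weights at request time.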
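The run_inference_gpu / run_inference_cpu split mirrors how Hugging Face ZeroGPU Spaces work: the @spaces.GPU decorator attaches a GPU only for the duration of the decorated call, so a separate CPU path with CPU-resident model copies can serve requests without claiming a GPU slot at all. Continuing the sketch above (the predict_* names are hypothetical; app.py's actual dispatcher is run_inference(with_gpu, ...)):

import spaces  # Hugging Face helper package; app.py already uses @spaces.GPU

@spaces.GPU
def predict_gpu(mel):
    # A GPU is attached only while this function executes.
    return run_inference(model_gpu, mel, GPU_DEVICE).cpu()

def predict_cpu(mel):
    return run_inference(model_cpu, mel, torch.device("cpu"))

def predict(with_gpu, mel):
    # Same dispatch shape as run_inference(with_gpu, ...) above.
    return predict_gpu(mel) if with_gpu else predict_cpu(mel)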