# AudioLDM / scripts / text2sound.py
# inoculatemedia's picture
# Upload 93 files
# 8815caa verified
import os
from audioldm import text_to_audio, build_model, save_wave
import argparse
# Command-line interface of the script. Each option is described as a
# (flags, settings) pair and registered in the order listed here; every
# option is optional and falls back to its default when omitted.
parser = argparse.ArgumentParser()

_cli_options = [
    (
        ("-t", "--text"),
        dict(
            type=str,
            default="A hammer is hitting a wooden surface",
            help="Text prompt to the model for audio generation",
        ),
    ),
    (
        ("-s", "--save_path"),
        dict(
            type=str,
            default="./output",
            help="The path to save model output",
        ),
    ),
    (
        ("-ckpt", "--ckpt_path"),
        dict(
            type=str,
            default="./ckpt/audioldm-s-full.ckpt",
            help="The path to the pretrained .ckpt model",
        ),
    ),
    (
        ("-b", "--batchsize"),
        dict(
            type=int,
            default=1,
            help="Generate how many samples at the same time",
        ),
    ),
    (
        ("-gs", "--guidance_scale"),
        dict(
            type=float,
            default=2.5,
            help="Guidance scale (Large => better quality and relavancy to text; Small => better diversity)",
        ),
    ),
    (
        ("-dur", "--duration"),
        dict(
            type=float,
            default=10.0,
            help="The duration of the samples",
        ),
    ),
    (
        ("-n", "--n_candidate_gen_per_text"),
        dict(
            type=int,
            default=3,
            help="Automatic quality control. This number control the number of candidates (e.g., generate three audios and choose the best to show you). A Larger value usually lead to better quality with heavier computation",
        ),
    ),
    (
        ("--seed",),
        dict(
            type=int,
            default=42,
            help="Change this value (any integer number) will lead to a different generation result.",
        ),
    ),
]

for _flags, _settings in _cli_options:
    parser.add_argument(*_flags, required=False, **_settings)
args = parser.parse_args()

# Validate the requested duration through argparse rather than `assert`:
# assertions are removed when Python runs with -O, and parser.error() prints
# the usage text and exits with status 2 instead of dumping a traceback.
# Zero and negative values are rejected too; they satisfy the modulo test
# but are not usable clip lengths.
if args.duration <= 0 or args.duration % 2.5 != 0:
    parser.error("Duration must be a positive multiple of 2.5")

# Create the output directory up front (no error if it already exists).
os.makedirs(args.save_path, exist_ok=True)

# Build the model from the checkpoint path given on the command line.
audioldm = build_model(ckpt_path=args.ckpt_path)

# Generate audio for the prompt with the user-supplied sampling settings.
waveform = text_to_audio(
    audioldm,
    args.text,
    seed=args.seed,
    duration=args.duration,
    guidance_scale=args.guidance_scale,
    n_candidate_gen_per_text=args.n_candidate_gen_per_text,
    batchsize=args.batchsize,
)

# The prompt text is passed as the name for the saved output.
save_wave(waveform, args.save_path, name=args.text)