Spaces:
Runtime error
Runtime error
Update FantasyTalking/infer.py
Browse files- FantasyTalking/infer.py +112 -23
FantasyTalking/infer.py
CHANGED
@@ -18,25 +18,112 @@ from FantasyTalking.utils import get_audio_features, resize_image_by_longest_edg
|
|
18 |
|
19 |
|
20 |
def parse_args():
|
21 |
-
parser = argparse.ArgumentParser(description="
|
22 |
-
|
23 |
-
parser.add_argument(
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
parser.add_argument(
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
parser.add_argument(
|
38 |
-
|
39 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
40 |
|
41 |
|
42 |
def load_models(args):
|
@@ -61,7 +148,9 @@ def load_models(args):
|
|
61 |
)
|
62 |
print("β
Wan I2V models loaded.")
|
63 |
|
64 |
-
pipe = WanVideoPipeline.from_model_manager(
|
|
|
|
|
65 |
|
66 |
print("π Loading FantasyTalking model...")
|
67 |
fantasytalking = FantasyTalkingAudioConditionModel(pipe.dit, 768, 2048).to("cuda")
|
@@ -86,7 +175,7 @@ def main(args, pipe, fantasytalking, wav2vec_processor, wav2vec):
|
|
86 |
print(f"π Getting duration of audio: {args.audio_path}")
|
87 |
duration = librosa.get_duration(filename=args.audio_path)
|
88 |
print(f"ποΈ Duration: {duration:.2f}s")
|
89 |
-
|
90 |
latents_num_frames = min(int(duration * args.fps / 4), args.max_num_frames // 4)
|
91 |
num_frames = (latents_num_frames - 1) * 4
|
92 |
print(f"π½οΈ Calculated number of frames: {num_frames}")
|
@@ -128,7 +217,7 @@ def main(args, pipe, fantasytalking, wav2vec_processor, wav2vec):
|
|
128 |
audio_cfg_scale=args.audio_cfg_scale,
|
129 |
audio_proj=audio_proj_split,
|
130 |
audio_context_lens=audio_context_lens,
|
131 |
-
latents_num_frames=
|
132 |
)
|
133 |
print("β
Video frames generated.")
|
134 |
|
@@ -158,4 +247,4 @@ if __name__ == "__main__":
|
|
158 |
args = parse_args()
|
159 |
pipe, fantasytalking, wav2vec_processor, wav2vec = load_models(args)
|
160 |
video_path = main(args, pipe, fantasytalking, wav2vec_processor, wav2vec)
|
161 |
-
print(f"π Done! Final video path: {video_path}")
|
|
|
18 |
|
19 |
|
20 |
def parse_args():
|
21 |
+
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
22 |
+
|
23 |
+
parser.add_argument(
|
24 |
+
"--wan_model_dir",
|
25 |
+
type=str,
|
26 |
+
default="./models/Wan2.1-I2V-14B-720P",
|
27 |
+
required=False,
|
28 |
+
help="The dir of the Wan I2V 14B model.",
|
29 |
+
)
|
30 |
+
parser.add_argument(
|
31 |
+
"--fantasytalking_model_path",
|
32 |
+
type=str,
|
33 |
+
default="./models/fantasytalking_model.ckpt",
|
34 |
+
required=False,
|
35 |
+
help="The .ckpt path of fantasytalking model.",
|
36 |
+
)
|
37 |
+
parser.add_argument(
|
38 |
+
"--wav2vec_model_dir",
|
39 |
+
type=str,
|
40 |
+
default="./models/wav2vec2-base-960h",
|
41 |
+
required=False,
|
42 |
+
help="The dir of wav2vec model.",
|
43 |
+
)
|
44 |
+
|
45 |
+
parser.add_argument(
|
46 |
+
"--image_path",
|
47 |
+
type=str,
|
48 |
+
default="./assets/images/woman.png",
|
49 |
+
required=False,
|
50 |
+
help="The path of the image.",
|
51 |
+
)
|
52 |
+
|
53 |
+
parser.add_argument(
|
54 |
+
"--audio_path",
|
55 |
+
type=str,
|
56 |
+
default="./assets/audios/woman.wav",
|
57 |
+
required=False,
|
58 |
+
help="The path of the audio.",
|
59 |
+
)
|
60 |
+
parser.add_argument(
|
61 |
+
"--prompt",
|
62 |
+
type=str,
|
63 |
+
default="A woman is talking.",
|
64 |
+
required=False,
|
65 |
+
help="prompt.",
|
66 |
+
)
|
67 |
+
parser.add_argument(
|
68 |
+
"--output_dir",
|
69 |
+
type=str,
|
70 |
+
default="./output",
|
71 |
+
help="Dir to save the model.",
|
72 |
+
)
|
73 |
+
parser.add_argument(
|
74 |
+
"--image_size",
|
75 |
+
type=int,
|
76 |
+
default=512,
|
77 |
+
help="The image will be resized proportionally to this size.",
|
78 |
+
)
|
79 |
+
parser.add_argument(
|
80 |
+
"--audio_scale",
|
81 |
+
type=float,
|
82 |
+
default=1.0,
|
83 |
+
help="Audio condition injection weight",
|
84 |
+
)
|
85 |
+
parser.add_argument(
|
86 |
+
"--prompt_cfg_scale",
|
87 |
+
type=float,
|
88 |
+
default=5.0,
|
89 |
+
required=False,
|
90 |
+
help="Prompt cfg scale",
|
91 |
+
)
|
92 |
+
parser.add_argument(
|
93 |
+
"--audio_cfg_scale",
|
94 |
+
type=float,
|
95 |
+
default=5.0,
|
96 |
+
required=False,
|
97 |
+
help="Audio cfg scale",
|
98 |
+
)
|
99 |
+
parser.add_argument(
|
100 |
+
"--max_num_frames",
|
101 |
+
type=int,
|
102 |
+
default=81,
|
103 |
+
required=False,
|
104 |
+
help="The maximum frames for generating videos, the audio part exceeding max_num_frames/fps will be truncated.",
|
105 |
+
)
|
106 |
+
parser.add_argument(
|
107 |
+
"--fps",
|
108 |
+
type=int,
|
109 |
+
default=23,
|
110 |
+
required=False,
|
111 |
+
)
|
112 |
+
parser.add_argument(
|
113 |
+
"--num_persistent_param_in_dit",
|
114 |
+
type=int,
|
115 |
+
default=None,
|
116 |
+
required=False,
|
117 |
+
help="Maximum parameter quantity retained in video memory, small number to reduce VRAM required",
|
118 |
+
)
|
119 |
+
parser.add_argument(
|
120 |
+
"--seed",
|
121 |
+
type=int,
|
122 |
+
default=1111,
|
123 |
+
required=False,
|
124 |
+
)
|
125 |
+
args = parser.parse_args()
|
126 |
+
return args
|
127 |
|
128 |
|
129 |
def load_models(args):
|
|
|
148 |
)
|
149 |
print("β
Wan I2V models loaded.")
|
150 |
|
151 |
+
pipe = WanVideoPipeline.from_model_manager(
|
152 |
+
model_manager, torch_dtype=torch.bfloat16, device="cuda"
|
153 |
+
)
|
154 |
|
155 |
print("π Loading FantasyTalking model...")
|
156 |
fantasytalking = FantasyTalkingAudioConditionModel(pipe.dit, 768, 2048).to("cuda")
|
|
|
175 |
print(f"π Getting duration of audio: {args.audio_path}")
|
176 |
duration = librosa.get_duration(filename=args.audio_path)
|
177 |
print(f"ποΈ Duration: {duration:.2f}s")
|
178 |
+
|
179 |
latents_num_frames = min(int(duration * args.fps / 4), args.max_num_frames // 4)
|
180 |
num_frames = (latents_num_frames - 1) * 4
|
181 |
print(f"π½οΈ Calculated number of frames: {num_frames}")
|
|
|
217 |
audio_cfg_scale=args.audio_cfg_scale,
|
218 |
audio_proj=audio_proj_split,
|
219 |
audio_context_lens=audio_context_lens,
|
220 |
+
latents_num_frames=(num_frames - 1) // 4 + 1,
|
221 |
)
|
222 |
print("β
Video frames generated.")
|
223 |
|
|
|
247 |
args = parse_args()
|
248 |
pipe, fantasytalking, wav2vec_processor, wav2vec = load_models(args)
|
249 |
video_path = main(args, pipe, fantasytalking, wav2vec_processor, wav2vec)
|
250 |
+
print(f"π Done! Final video path: {video_path}")
|