wedyanessam committed on
Commit
56e8444
·
verified ·
1 Parent(s): 39618e2

Update FantasyTalking/infer.py

Browse files
Files changed (1) hide show
  1. FantasyTalking/infer.py +112 -23
FantasyTalking/infer.py CHANGED
@@ -18,25 +18,112 @@ from FantasyTalking.utils import get_audio_features, resize_image_by_longest_edg
18
 
19
 
20
  def parse_args():
21
- parser = argparse.ArgumentParser(description="FantasyTalking Video Generator")
22
-
23
- parser.add_argument("--wan_model_dir", type=str, default="./models/Wan2.1-I2V-14B-720P")
24
- parser.add_argument("--fantasytalking_model_path", type=str, default="./models/fantasytalking_model.ckpt")
25
- parser.add_argument("--wav2vec_model_dir", type=str, default="./models/wav2vec2-base-960h")
26
- parser.add_argument("--image_path", type=str, default="./assets/images/woman.png")
27
- parser.add_argument("--audio_path", type=str, default="./assets/audios/woman.wav")
28
- parser.add_argument("--prompt", type=str, default="A woman is talking.")
29
- parser.add_argument("--output_dir", type=str, default="./output")
30
- parser.add_argument("--image_size", type=int, default=512)
31
- parser.add_argument("--audio_scale", type=float, default=1.0)
32
- parser.add_argument("--prompt_cfg_scale", type=float, default=5.0)
33
- parser.add_argument("--audio_cfg_scale", type=float, default=5.0)
34
- parser.add_argument("--max_num_frames", type=int, default=81)
35
- parser.add_argument("--fps", type=int, default=23)
36
- parser.add_argument("--num_persistent_param_in_dit", type=int, default=None)
37
- parser.add_argument("--seed", type=int, default=1111)
38
-
39
- return parser.parse_args()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
 
42
  def load_models(args):
@@ -61,7 +148,9 @@ def load_models(args):
61
  )
62
  print("✅ Wan I2V models loaded.")
63
 
64
- pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda")
 
 
65
 
66
  print("🔄 Loading FantasyTalking model...")
67
  fantasytalking = FantasyTalkingAudioConditionModel(pipe.dit, 768, 2048).to("cuda")
@@ -86,7 +175,7 @@ def main(args, pipe, fantasytalking, wav2vec_processor, wav2vec):
86
  print(f"🔊 Getting duration of audio: {args.audio_path}")
87
  duration = librosa.get_duration(filename=args.audio_path)
88
  print(f"🎞️ Duration: {duration:.2f}s")
89
-
90
  latents_num_frames = min(int(duration * args.fps / 4), args.max_num_frames // 4)
91
  num_frames = (latents_num_frames - 1) * 4
92
  print(f"πŸ“½οΈ Calculated number of frames: {num_frames}")
@@ -128,7 +217,7 @@ def main(args, pipe, fantasytalking, wav2vec_processor, wav2vec):
128
  audio_cfg_scale=args.audio_cfg_scale,
129
  audio_proj=audio_proj_split,
130
  audio_context_lens=audio_context_lens,
131
- latents_num_frames=latents_num_frames,
132
  )
133
  print("✅ Video frames generated.")
134
 
@@ -158,4 +247,4 @@ if __name__ == "__main__":
158
  args = parse_args()
159
  pipe, fantasytalking, wav2vec_processor, wav2vec = load_models(args)
160
  video_path = main(args, pipe, fantasytalking, wav2vec_processor, wav2vec)
161
- print(f"🎉 Done! Final video path: {video_path}")
 
18
 
19
 
20
  def parse_args():
21
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
22
+
23
+ parser.add_argument(
24
+ "--wan_model_dir",
25
+ type=str,
26
+ default="./models/Wan2.1-I2V-14B-720P",
27
+ required=False,
28
+ help="The dir of the Wan I2V 14B model.",
29
+ )
30
+ parser.add_argument(
31
+ "--fantasytalking_model_path",
32
+ type=str,
33
+ default="./models/fantasytalking_model.ckpt",
34
+ required=False,
35
+ help="The .ckpt path of fantasytalking model.",
36
+ )
37
+ parser.add_argument(
38
+ "--wav2vec_model_dir",
39
+ type=str,
40
+ default="./models/wav2vec2-base-960h",
41
+ required=False,
42
+ help="The dir of wav2vec model.",
43
+ )
44
+
45
+ parser.add_argument(
46
+ "--image_path",
47
+ type=str,
48
+ default="./assets/images/woman.png",
49
+ required=False,
50
+ help="The path of the image.",
51
+ )
52
+
53
+ parser.add_argument(
54
+ "--audio_path",
55
+ type=str,
56
+ default="./assets/audios/woman.wav",
57
+ required=False,
58
+ help="The path of the audio.",
59
+ )
60
+ parser.add_argument(
61
+ "--prompt",
62
+ type=str,
63
+ default="A woman is talking.",
64
+ required=False,
65
+ help="prompt.",
66
+ )
67
+ parser.add_argument(
68
+ "--output_dir",
69
+ type=str,
70
+ default="./output",
71
+ help="Dir to save the model.",
72
+ )
73
+ parser.add_argument(
74
+ "--image_size",
75
+ type=int,
76
+ default=512,
77
+ help="The image will be resized proportionally to this size.",
78
+ )
79
+ parser.add_argument(
80
+ "--audio_scale",
81
+ type=float,
82
+ default=1.0,
83
+ help="Audio condition injection weight",
84
+ )
85
+ parser.add_argument(
86
+ "--prompt_cfg_scale",
87
+ type=float,
88
+ default=5.0,
89
+ required=False,
90
+ help="Prompt cfg scale",
91
+ )
92
+ parser.add_argument(
93
+ "--audio_cfg_scale",
94
+ type=float,
95
+ default=5.0,
96
+ required=False,
97
+ help="Audio cfg scale",
98
+ )
99
+ parser.add_argument(
100
+ "--max_num_frames",
101
+ type=int,
102
+ default=81,
103
+ required=False,
104
+ help="The maximum frames for generating videos, the audio part exceeding max_num_frames/fps will be truncated.",
105
+ )
106
+ parser.add_argument(
107
+ "--fps",
108
+ type=int,
109
+ default=23,
110
+ required=False,
111
+ )
112
+ parser.add_argument(
113
+ "--num_persistent_param_in_dit",
114
+ type=int,
115
+ default=None,
116
+ required=False,
117
+ help="Maximum parameter quantity retained in video memory, small number to reduce VRAM required",
118
+ )
119
+ parser.add_argument(
120
+ "--seed",
121
+ type=int,
122
+ default=1111,
123
+ required=False,
124
+ )
125
+ args = parser.parse_args()
126
+ return args
127
 
128
 
129
  def load_models(args):
 
148
  )
149
  print("✅ Wan I2V models loaded.")
150
 
151
+ pipe = WanVideoPipeline.from_model_manager(
152
+ model_manager, torch_dtype=torch.bfloat16, device="cuda"
153
+ )
154
 
155
  print("🔄 Loading FantasyTalking model...")
156
  fantasytalking = FantasyTalkingAudioConditionModel(pipe.dit, 768, 2048).to("cuda")
 
175
  print(f"🔊 Getting duration of audio: {args.audio_path}")
176
  duration = librosa.get_duration(filename=args.audio_path)
177
  print(f"🎞️ Duration: {duration:.2f}s")
178
+
179
  latents_num_frames = min(int(duration * args.fps / 4), args.max_num_frames // 4)
180
  num_frames = (latents_num_frames - 1) * 4
181
  print(f"πŸ“½οΈ Calculated number of frames: {num_frames}")
 
217
  audio_cfg_scale=args.audio_cfg_scale,
218
  audio_proj=audio_proj_split,
219
  audio_context_lens=audio_context_lens,
220
+ latents_num_frames=(num_frames - 1) // 4 + 1,
221
  )
222
  print("✅ Video frames generated.")
223
 
 
247
  args = parse_args()
248
  pipe, fantasytalking, wav2vec_processor, wav2vec = load_models(args)
249
  video_path = main(args, pipe, fantasytalking, wav2vec_processor, wav2vec)
250
+ print(f"🎉 Done! Final video path: {video_path}")