wedyanessam committed
Commit df3d223 · verified · 1 Parent(s): bd88065

Update FantasyTalking/infer.py

Files changed (1)
  1. FantasyTalking/infer.py +38 -25
FantasyTalking/infer.py CHANGED
@@ -127,7 +127,7 @@ def parse_args():
 
 
 def load_models(args):
-    # Load Wan I2V models
+    print("🔄 Loading Wan I2V models...")
     model_manager = ModelManager(device="cpu")
     model_manager.load_models(
         [
@@ -144,50 +144,63 @@ def load_models(args):
             f"{args.wan_model_dir}/models_t5_umt5-xxl-enc-bf16.pth",
             f"{args.wan_model_dir}/Wan2.1_VAE.pth",
         ],
-        # torch_dtype=torch.float8_e4m3fn, # You can set `torch_dtype=torch.bfloat16` to disable FP8 quantization.
-        torch_dtype=torch.bfloat16,  # You can set `torch_dtype=torch.bfloat16` to disable FP8 quantization.
+        torch_dtype=torch.bfloat16,
     )
+    print("✅ Wan I2V models loaded.")
+
     pipe = WanVideoPipeline.from_model_manager(
         model_manager, torch_dtype=torch.bfloat16, device="cuda"
     )
 
-    # Load FantasyTalking weights
+    print("🔄 Loading FantasyTalking model...")
     fantasytalking = FantasyTalkingAudioConditionModel(pipe.dit, 768, 2048).to("cuda")
     fantasytalking.load_audio_processor(args.fantasytalking_model_path, pipe.dit)
+    print("✅ FantasyTalking model loaded.")
 
-    # You can set `num_persistent_param_in_dit` to a small number to reduce VRAM required.
-    pipe.enable_vram_management(
-        num_persistent_param_in_dit=args.num_persistent_param_in_dit
-    )
+    print("🧠 Enabling VRAM management...")
+    pipe.enable_vram_management(num_persistent_param_in_dit=args.num_persistent_param_in_dit)
 
-    # Load wav2vec models
+    print("🔄 Loading Wav2Vec2 processor and model...")
    wav2vec_processor = Wav2Vec2Processor.from_pretrained(args.wav2vec_model_dir)
     wav2vec = Wav2Vec2Model.from_pretrained(args.wav2vec_model_dir).to("cuda")
+    print("✅ Wav2Vec2 loaded.")
 
     return pipe, fantasytalking, wav2vec_processor, wav2vec
 
 
 def main(args, pipe, fantasytalking, wav2vec_processor, wav2vec):
+    print("📁 Creating output directory...")
     os.makedirs(args.output_dir, exist_ok=True)
 
+    print(f"🔊 Getting duration of audio: {args.audio_path}")
     duration = librosa.get_duration(filename=args.audio_path)
+    print(f"🎞️ Duration: {duration:.2f}s")
+
     num_frames = min(int(args.fps * duration // 4) * 4 + 5, args.max_num_frames)
+    print(f"📽️ Calculated number of frames: {num_frames}")
 
+    print("🎧 Extracting audio features...")
     audio_wav2vec_fea = get_audio_features(
         wav2vec, wav2vec_processor, args.audio_path, args.fps, num_frames
     )
+    print("✅ Audio features extracted.")
+
+    print("🖼️ Loading and resizing image...")
     image = resize_image_by_longest_edge(args.image_path, args.image_size)
     width, height = image.size
+    print(f"✅ Image resized to: {width}x{height}")
 
+    print("🔄 Projecting audio features...")
     audio_proj_fea = fantasytalking.get_proj_fea(audio_wav2vec_fea)
     pos_idx_ranges = fantasytalking.split_audio_sequence(
         audio_proj_fea.size(1), num_frames=num_frames
     )
     audio_proj_split, audio_context_lens = fantasytalking.split_tensor_with_padding(
         audio_proj_fea, pos_idx_ranges, expand_length=4
-    )  # [b,21,9+8,768]
+    )
+    print("✅ Audio features projected and split.")
 
-    # Image-to-video
+    print("🚀 Generating video from image + audio...")
     video_audio = pipe(
         prompt=args.prompt,
         negative_prompt="人物静止不动,静止,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
@@ -205,32 +218,32 @@ def main(args, pipe, fantasytalking, wav2vec_processor, wav2vec):
         audio_context_lens=audio_context_lens,
         latents_num_frames=(num_frames - 1) // 4 + 1,
     )
+    print("✅ Video frames generated.")
+
     current_time = datetime.now().strftime("%Y%m%d_%H%M%S")
     save_path_tmp = f"{args.output_dir}/tmp_{Path(args.image_path).stem}_{Path(args.audio_path).stem}_{current_time}.mp4"
+    print(f"💾 Saving temporary video without audio to: {save_path_tmp}")
     save_video(video_audio, save_path_tmp, fps=args.fps, quality=5)
 
     save_path = f"{args.output_dir}/{Path(args.image_path).stem}_{Path(args.audio_path).stem}_{current_time}.mp4"
+    print(f"🔊 Merging video with audio using FFmpeg...")
+
     final_command = [
-        "ffmpeg",
-        "-y",
-        "-i",
-        save_path_tmp,
-        "-i",
-        args.audio_path,
-        "-c:v",
-        "libx264",
-        "-c:a",
-        "aac",
-        "-shortest",
-        save_path,
+        "ffmpeg", "-y", "-i", save_path_tmp, "-i", args.audio_path,
+        "-c:v", "libx264", "-c:a", "aac", "-shortest", save_path,
     ]
     subprocess.run(final_command, check=True)
+    print(f"✅ Final video saved to: {save_path}")
+
+    print("🧹 Removing temporary video file...")
     os.remove(save_path_tmp)
+
     return save_path
 
 
 if __name__ == "__main__":
+    print("🚦 Starting main script...")
     args = parse_args()
     pipe, fantasytalking, wav2vec_processor, wav2vec = load_models(args)
-
-    main(args, pipe, fantasytalking, wav2vec_processor, wav2vec)
+    video_path = main(args, pipe, fantasytalking, wav2vec_processor, wav2vec)
+    print(f"🎉 Done! Final video path: {video_path}")
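A note on the frame-count line this diff keeps: int(args.fps * duration // 4) * 4 + 5 always yields a value of the form 4k + 1, which is exactly what latents_num_frames = (num_frames - 1) // 4 + 1 assumes. A minimal sketch of the arithmetic, using assumed inputs (fps=24, a 5-second clip, max_num_frames=81; none of these values come from this commit):

# Sketch only: fps, duration, and max_num_frames are assumed, not from the commit.
fps, duration, max_num_frames = 24, 5.0, 81
num_frames = min(int(fps * duration // 4) * 4 + 5, max_num_frames)  # 4k + 5 == 4(k + 1) + 1
latents_num_frames = (num_frames - 1) // 4 + 1
print(num_frames, latents_num_frames)  # 81 21 (the 125-frame estimate is capped by max_num_frames)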
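The compacted final_command is behaviorally identical to the multi-line version it replaces. A commented sketch of the same FFmpeg call (the helper name mux_audio is hypothetical; the flags are exactly those in the diff):

import subprocess

def mux_audio(video_path: str, audio_path: str, out_path: str) -> None:
    # Mirrors the diff's final_command, with each flag annotated.
    subprocess.run(
        [
            "ffmpeg",
            "-y",               # overwrite the output file if it already exists
            "-i", video_path,   # first input: the silent rendered video
            "-i", audio_path,   # second input: the driving audio track
            "-c:v", "libx264",  # encode video as H.264
            "-c:a", "aac",      # encode audio as AAC
            "-shortest",        # stop at the end of the shorter input stream
            out_path,
        ],
        check=True,  # raise CalledProcessError on a non-zero exit code
    )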
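Because main now returns save_path and takes the model handles explicitly, the script can also be driven programmatically. A sketch under assumptions: the attribute names match the args reads visible in this diff, every value below is an illustrative placeholder rather than a default from parse_args(), and the pipe(...) keyword arguments elided between the hunks may read further attributes not listed here.

from types import SimpleNamespace

# All values are placeholders; parse_args() remains the source of truth.
# NOTE: pipe(...) kwargs elided from the diff may require additional attributes.
args = SimpleNamespace(
    wan_model_dir="path/to/wan_models",
    fantasytalking_model_path="path/to/fantasytalking.ckpt",
    wav2vec_model_dir="path/to/wav2vec2",
    image_path="path/to/portrait.png",
    audio_path="path/to/speech.wav",
    prompt="A person talking",
    output_dir="output",
    image_size=512,
    fps=24,
    max_num_frames=81,
    num_persistent_param_in_dit=None,
)

pipe, fantasytalking, wav2vec_processor, wav2vec = load_models(args)
video_path = main(args, pipe, fantasytalking, wav2vec_processor, wav2vec)
print(video_path)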