wedyanessam committed on
Commit a073983 · verified · 1 Parent(s): b181f9d

Create inference.py

Files changed (1)
  1. FantasyTalking/inference.py +50 -0
FantasyTalking/inference.py ADDED
@@ -0,0 +1,50 @@
+ import os
+ import torch
+ from PIL import Image
+ import torchvision.transforms as transforms
+ from fantasy_talking.model import FantasyTalkingModel
+ from moviepy.editor import ImageSequenceClip
+ import torchaudio
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ # Load the model
+ model_ckpt = "./models/fantasytalking_model.ckpt"
+ model = FantasyTalkingModel()
+ model.load_state_dict(torch.load(model_ckpt, map_location=device))
+ model = model.to(device)
+ model.eval()
+
+ # Convert the image to a tensor
+ def load_image(image_path):
+     image = Image.open(image_path).convert("RGB")
+     transform = transforms.Compose([
+         transforms.Resize((512, 512)),
+         transforms.ToTensor()
+     ])
+     return transform(image).unsqueeze(0).to(device)
+
+ # Convert the audio to a tensor
+ def load_audio(audio_path):
+     waveform, sample_rate = torchaudio.load(audio_path)
+     if waveform.shape[0] > 1:
+         waveform = waveform.mean(dim=0, keepdim=True)
+     if sample_rate != 16000:
+         resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
+         waveform = resampler(waveform)
+     return waveform.to(device), 16000
+
+ # Generate the video
+ def generate_video(image_path, audio_path, output_path="output.mp4"):
+     image_tensor = load_image(image_path)
+     audio_tensor, _ = load_audio(audio_path)
+
+     with torch.no_grad():
+         frames = model.generate(image_tensor, audio_tensor)
+
+     # Save the video from the frames (moviepy expects uint8 HWC numpy arrays)
+     frames = [(frame.squeeze(0).cpu().clamp(0, 1).permute(1, 2, 0).numpy() * 255).astype("uint8") for frame in frames]
+     video_clip = ImageSequenceClip(frames, fps=25)
+     video_clip.write_videofile(output_path, codec="libx264")
+
+     return output_path
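
A minimal usage sketch for the new script, assuming the repository root is on PYTHONPATH and the checkpoint at ./models/fantasytalking_model.ckpt is in place; the image and audio paths are placeholders, not files shipped with this commit:

    # hypothetical example paths; loading the module runs the checkpoint load at import time
    from FantasyTalking.inference import generate_video

    result = generate_video("examples/portrait.png", "examples/speech.wav", output_path="talking_head.mp4")
    print(result)  # path of the written mp4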