wedyanessam committed on
Commit
36fe92a
·
verified ·
1 Parent(s): cb36cbf

Delete FantasyTalking/inference.py

Browse files
Files changed (1) hide show
  1. FantasyTalking/inference.py +0 -50
FantasyTalking/inference.py DELETED
@@ -1,50 +0,0 @@
1
- import os
2
- import torch
3
- from PIL import Image
4
- import torchvision.transforms as transforms
5
- from fantasy_talking.model import FantasyTalkingModel
6
- from moviepy.editor import ImageSequenceClip
7
- import torchaudio
8
-
9
# Prefer the GPU when CUDA is available; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the model: instantiate it, restore the checkpoint weights (mapped
# onto the selected device), move it to that device and call eval().
# NOTE(review): torch.load may unpickle arbitrary objects depending on the
# installed PyTorch version -- only point this at a trusted checkpoint.
model_ckpt = "./models/fantasytalking_model.ckpt"
model = FantasyTalkingModel()
model.load_state_dict(
    torch.load(model_ckpt, map_location=device),
)
model = model.to(device)
model.eval()
17
-
18
- # Convert the image to a tensor
19
def load_image(image_path):
    """Load an image file as a batched tensor on the active device.

    The image is converted to RGB, resized to 512x512 and converted with
    ``transforms.ToTensor``; ``unsqueeze(0)`` then adds a leading batch
    dimension.

    Args:
        image_path: Location of the image, in any form ``Image.open`` accepts.

    Returns:
        The preprocessed image tensor, moved to the module-level ``device``.
    """
    preprocess = transforms.Compose([
        transforms.Resize((512, 512)),
        transforms.ToTensor(),
    ])
    # Open through a context manager so the file handle is released
    # deterministically; convert() yields a separate in-memory image that
    # remains usable after the source file has been closed.
    with Image.open(image_path) as source_image:
        rgb_image = source_image.convert("RGB")
    return preprocess(rgb_image).unsqueeze(0).to(device)
26
-
27
- # Convert the audio to a tensor
28
def load_audio(audio_path):
    """Load an audio file as a single-channel 16 kHz waveform on the device.

    Args:
        audio_path: Location of the audio file, passed to ``torchaudio.load``.

    Returns:
        A ``(waveform, sample_rate)`` tuple: the waveform tensor moved to the
        module-level ``device``, and the sample rate, which is always 16000.
    """
    target_rate = 16000
    samples, source_rate = torchaudio.load(audio_path)

    # More than one entry along dim 0 is averaged down to a single one
    # (presumably dim 0 is the channel axis -- confirm against torchaudio).
    if samples.shape[0] > 1:
        samples = samples.mean(dim=0, keepdim=True)

    # Resample only when the file is not already at the target rate.
    if source_rate != target_rate:
        samples = torchaudio.transforms.Resample(
            orig_freq=source_rate,
            new_freq=target_rate,
        )(samples)

    return samples.to(device), target_rate
36
-
37
- # Generate the video
38
def generate_video(image_path, audio_path, output_path="output.mp4"):
    """Generate a video from a still image and an audio file.

    The image and audio are loaded with ``load_image`` and ``load_audio``,
    passed to ``model.generate`` under ``torch.no_grad()``, and the returned
    frames are written to ``output_path`` at 25 fps with the libx264 codec.

    Args:
        image_path: Path of the source image.
        audio_path: Path of the driving audio file.
        output_path: Destination video file; defaults to ``"output.mp4"``.

    Returns:
        ``output_path``, once the video file has been written.
    """
    # Local import: numpy is not among the imports visible at module level.
    import numpy as np

    image_tensor = load_image(image_path)
    audio_tensor, _ = load_audio(audio_path)

    with torch.no_grad():
        frames = model.generate(image_tensor, audio_tensor)

    # Save the video from the frames. ImageSequenceClip takes file names or
    # numpy arrays (it reads ``.shape`` from non-string items), so each frame
    # goes tensor -> PIL image -> array rather than being passed as PIL.
    to_pil = transforms.ToPILImage()
    frame_arrays = [
        np.array(to_pil(frame.squeeze(0).cpu())) for frame in frames
    ]
    video_clip = ImageSequenceClip(frame_arrays, fps=25)
    # NOTE(review): no audio track is attached to the clip here -- confirm
    # whether the written video is meant to be silent.
    video_clip.write_videofile(output_path, codec="libx264")

    return output_path