# GHOST-2.0 / app.py
# Author: ai-forever
# Commit 9633c48 ("Update app.py", verified)
import spaces
import gradio as gr
from huggingface_hub import snapshot_download
import os
# Mirror the whole Hugging Face repository into the working directory before
# the project-local imports further down run (those `src.*`, `repos.*` and
# `train_*` modules presumably come from this snapshot — confirm).
repo_id = "ai-forever/GHOST-2.0-repo"  # HF repo to mirror
local_dir = "./"  # snapshot destination: the current working directory
snapshot_download(
    local_dir=local_dir,
    repo_id=repo_id,
    token=os.getenv("HF_TOKEN"),  # None when HF_TOKEN is not set
)
print(f"Repository downloaded to: {local_dir}")
import cv2
import torch
import argparse
import yaml
from torchvision import transforms
import onnxruntime as ort
from PIL import Image
from insightface.app import FaceAnalysis
from omegaconf import OmegaConf
from torchvision.transforms.functional import rgb_to_grayscale
from src.utils.crops import *
from repos.stylematte.stylematte.models import StyleMatte
from src.utils.inference import *
from src.utils.inpainter import LamaInpainter
from src.utils.preblending import calc_pseudo_target_bg
from train_aligner import AlignerModule
from train_blender import BlenderModule
@spaces.GPU
def infer_headswap(source, target):
    """Transfer the head from ``source`` onto the person in ``target``.

    Uses the module-level models created in the ``__main__`` section:
    ``app`` (face detection), ``segment_model`` (StyleMatte mask),
    ``aligner``, ``infer_parsing`` (SegFormer ONNX parsing), ``blender`` and
    ``inpainter``.

    Args:
        source: image providing the head (PIL image from the Gradio
            ``gr.Image(type="pil")`` input).
        target: image whose head gets replaced (same type as ``source``).

    Returns:
        PIL.Image.Image: the target frame with the swapped head pasted back.

    Raises:
        gr.Error: if an image is missing or no face is detected on it.
    """

    def calc_mask(img):
        """Run StyleMatte on ``img`` and return ``out[0][0]`` of its output.

        ``img`` is either an HxWxC numpy image (moved to CUDA as CHW and
        scaled from 0..255 to [0, 1] when its max exceeds 1) or a CHW torch
        tensor already in [0, 1].
        """
        if isinstance(img, np.ndarray):
            img = torch.from_numpy(img).permute(2, 0, 1).cuda().float()
            # Only raw numpy frames can arrive in the 0..255 range; tensors
            # passed by the caller are already in [0, 1].
            if img.max() > 1.:
                img = img / 255.0
        # ImageNet mean/std normalization expected by the matting model.
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        input_t = normalize(img).unsqueeze(0).float()
        with torch.no_grad():
            out = segment_model(input_t)
        return out[0][0]

    def process_img(img, target=False):
        """Detect the first face in ``img`` and build the model inputs.

        Returns ``(wide, arc, mask)``; for the target additionally returns
        the BGR full frame (padded if detection needed a retry) and the
        crop transform ``M`` from ``wide_crop_face``.
        """
        role = 'target' if target else 'source'
        if img is None:
            raise gr.Error(f"no {role} image provided")
        if isinstance(img, Image.Image):
            # The pipeline assumes 3 channels; RGBA/grayscale would break it.
            img = img.convert("RGB")
        full_frames = np.array(img)[:, :, ::-1]  # RGB -> BGR
        dets = app.get(full_frames)
        if len(dets) == 0:
            # Retry on a frame padded by half its height/width on every side
            # (presumably helps when the head is large or touches the border).
            pad_h, pad_w = full_frames.shape[0] // 2, full_frames.shape[1] // 2
            full_frames = cv2.copyMakeBorder(
                full_frames, pad_h, pad_h, pad_w, pad_w, cv2.BORDER_CONSTANT, value=0)
            dets = app.get(full_frames)
            if len(dets) == 0:
                raise gr.Error(f"no head on {role} image")
        kps = dets[0]['kps']
        wide = wide_crop_face(full_frames, kps, return_M=target)
        if target:
            wide, M = wide
        arc = norm_crop(full_frames, kps)
        mask = calc_mask(wide)
        arc = normalize_and_torch(arc)
        wide = normalize_and_torch(wide)
        if target:
            return wide, arc, mask, full_frames, M
        return wide, arc, mask

    wide_source, arc_source, mask_source = process_img(source)
    wide_target, arc_target, mask_target, full_frame, M = process_img(target, target=True)
    wide_source = wide_source.unsqueeze(1)
    arc_source = arc_source.unsqueeze(1)

    X_dict = {
        'source': {
            'face_arc': arc_source,
            'face_wide': wide_source * mask_source,
            'face_wide_mask': mask_source
        },
        'target': {
            'face_arc': arc_target,
            'face_wide': wide_target * mask_target,
            'face_wide_mask': mask_target
        }
    }

    # Pure inference: keep the whole pipeline (including the blender) out of
    # autograd so no graph is built and GPU memory is not wasted.
    with torch.no_grad():
        output = aligner(X_dict)
        target_parsing = infer_parsing(wide_target)
        pseudo_norm_target = calc_pseudo_target_bg(wide_target, target_parsing)

        fake_face = output['fake_rgbs'] * output['fake_segm']
        # Channels reordered with [2, 1, 0] (presumably BGR -> RGB) and mapped
        # from [-1, 1] to [0, 1] before matting.
        soft_mask = calc_mask((fake_face[0, [2, 1, 0], :, :] + 1) / 2)[None]
        alpha = soft_mask[:, None, ...]
        # Generated head over the pseudo target background.
        new_source = output['fake_rgbs'] * alpha + pseudo_norm_target * (1 - alpha)

        blender_input = {
            'face_source': new_source,
            'gray_source': rgb_to_grayscale(new_source[0][[2, 1, 0], ...]).unsqueeze(0),
            'face_target': wide_target,
            'mask_source': infer_parsing(fake_face),
            'mask_target': target_parsing,
            'mask_source_noise': None,
            'mask_target_noise': None,
            'alpha_source': soft_mask
        }
        output_b = blender(blender_input, inpainter=inpainter)

    oup = output_b['oup'][0].detach().cpu().numpy().transpose((1, 2, 0))[:, :, ::-1]
    # Map [-1, 1] -> [0, 255]; clip first so out-of-range values saturate
    # instead of wrapping around in the uint8 cast.
    np_output = np.clip((oup / 2 + 0.5) * 255, 0, 255).astype(np.uint8)
    # NOTE(review): if process_img had to pad the target, full_frame and M
    # refer to the padded frame, so the result presumably keeps that black
    # border — confirm copy_head_back's output size and crop if needed.
    result = copy_head_back(np_output, full_frame[..., ::-1], M)
    return Image.fromarray(result)
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Generator params
    parser.add_argument('--config_a', default='./configs/aligner.yaml', type=str, help='Path to Aligner config')
    parser.add_argument('--config_b', default='./configs/blender.yaml', type=str, help='Path to Blender config')
    parser.add_argument('--source', default='./examples/images/hab.jpg', type=str, help='Path to source image')
    parser.add_argument('--target', default='./examples/images/elon.jpg', type=str, help='Path to target image')
    parser.add_argument('--ckpt_a', default='./aligner_checkpoints/aligner_1020_gaze_final.ckpt', type=str, help='Aligner checkpoint')
    parser.add_argument('--ckpt_b', default='./blender_checkpoints/blender_lama.ckpt', type=str, help='Blender checkpoint')
    parser.add_argument('--save_path', default='result.png', type=str, help='Path to save the result')
    args = parser.parse_args()

    with open(args.config_a, "r") as stream:
        cfg_a = OmegaConf.load(stream)
    with open(args.config_b, "r") as stream:
        cfg_b = OmegaConf.load(stream)

    # Aligner: load the checkpoint once, on CPU, then move the model to GPU.
    # NOTE(review): strict=False silently ignores missing/unexpected keys, and
    # unlike the blender checkpoint this one is not indexed by "state_dict" —
    # confirm the file is a bare state dict.
    aligner = AlignerModule(cfg_a)
    aligner.load_state_dict(torch.load(args.ckpt_a, map_location='cpu'), strict=False)
    aligner.eval()
    aligner.cuda()

    blender = BlenderModule(cfg_b)
    blender.load_state_dict(torch.load(args.ckpt_b, map_location='cpu')["state_dict"], strict=False,)
    blender.eval()
    blender.cuda()

    inpainter = LamaInpainter('cpu')

    # Face detector only (no recognition/landmark modules loaded).
    app = FaceAnalysis(providers=['CUDAExecutionProvider'], allowed_modules=['detection'])
    app.prepare(ctx_id=0, det_size=(640, 640))

    segment_model = StyleMatte()
    segment_model.load_state_dict(
        torch.load(
            './repos/stylematte/stylematte/checkpoints/stylematte_synth.pth',
            map_location='cpu'
        )
    )
    segment_model = segment_model.cuda()
    segment_model.eval()

    # SegFormer face-parsing model, run through ONNX Runtime on CUDA.
    providers = [
        ("CUDAExecutionProvider", {})
    ]
    parsings_session = ort.InferenceSession('./weights/segformer_B5_ce.onnx', providers=providers)
    input_name = parsings_session.get_inputs()[0].name
    output_names = [output.name for output in parsings_session.get_outputs()]
    # Per-channel normalization constants applied before the ONNX model.
    mean = np.array([0.51315393, 0.48064056, 0.46301059])[None, :, None, None]
    std = np.array([0.21438347, 0.20799829, 0.20304542])[None, :, None, None]

    def infer_parsing(img):
        """Run the SegFormer ONNX model on a batch tensor ``img``.

        The input is reordered with channels [2, 1, 0], mapped with
        ``x / 2 + 0.5`` (presumably from [-1, 1] to [0, 1]), normalized with
        ``mean``/``std`` and sent to ONNX Runtime as float32. Returns the
        first model output as a float32 CUDA tensor.
        """
        onnx_input = (((img[:, [2, 1, 0], ...] / 2 + 0.5).cpu().detach().numpy() - mean) / std).astype(np.float32)
        logits = parsings_session.run(output_names, {input_name: onnx_input})[0]
        return torch.tensor(logits, device='cuda', dtype=torch.float32)

    # NOTE(review): these example images are opened but never used by the UI.
    source_pil = Image.open(args.source)
    target_pil = Image.open(args.target)

    with gr.Blocks() as demo:
        with gr.Column():
            with gr.Row():
                with gr.Column():
                    with gr.Row(equal_height=True):
                        input_source = gr.Image(
                            type="pil",
                            label="Input Source"
                        )
                        input_target = gr.Image(
                            type="pil",
                            label="Input Target"
                        )
                    run_button = gr.Button("Generate")
                with gr.Column():
                    result = gr.Image(type='pil', label='Image Output')
        run_button.click(
            fn=infer_headswap,
            inputs=[input_source, input_target],
            outputs=[result]
        )
    demo.launch()