# GHOST-2.0 head-swap demo (Hugging Face Space, running on ZeroGPU)
import spaces
import gradio as gr
from huggingface_hub import snapshot_download
import os

# Define repository and local directory
repo_id = "ai-forever/GHOST-2.0-repo"  # HF repo with code and weights
local_dir = "./"  # Target local directory

# Download the entire repository
snapshot_download(repo_id=repo_id, local_dir=local_dir, token=os.getenv('HF_TOKEN'))
print(f"Repository downloaded to: {local_dir}")
import cv2
import numpy as np
import torch
import argparse
from torchvision import transforms
import onnxruntime as ort
from PIL import Image
from insightface.app import FaceAnalysis
from omegaconf import OmegaConf
from torchvision.transforms.functional import rgb_to_grayscale

# Repo-local modules
from src.utils.crops import *
from repos.stylematte.stylematte.models import StyleMatte
from src.utils.inference import *
from src.utils.inpainter import LamaInpainter
from src.utils.preblending import calc_pseudo_target_bg
from train_aligner import AlignerModule
from train_blender import BlenderModule
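
# Pipeline overview, as wired below: the Aligner reenacts the source head to
# match the target pose, the Blender composites the reenacted head into the
# target frame, and a LaMa inpainter fills background uncovered by the swap.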
@spaces.GPU  # ZeroGPU allocates a GPU only for the duration of decorated calls
def infer_headswap(source, target):

    def calc_mask(img):
        # Soft head/foreground mask from StyleMatte (ImageNet normalization)
        if isinstance(img, np.ndarray):
            img = torch.from_numpy(img).permute(2, 0, 1).cuda()
        if img.max() > 1.:
            img = img / 255.0
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        input_t = normalize(img)
        input_t = input_t.unsqueeze(0).float()
        with torch.no_grad():
            out = segment_model(input_t)
        result = out[0]
        return result[0]
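
    # process_img prepares one image for the networks: detect the face with
    # insightface, take a wide head crop plus a tight ArcFace-style identity
    # crop, and compute a StyleMatte head mask.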
    def process_img(img, target=False):
        # PIL RGB -> BGR numpy, as expected by insightface and the crop utils
        full_frames = np.array(img)[:, :, ::-1]
        dets = app.get(full_frames)
        if len(dets) == 0:
            # No face found: pad the frame by 50% on each side and retry,
            # in case the head sits too close to the image border
            pad_top, pad_bottom, pad_left, pad_right = (
                full_frames.shape[0] // 2, full_frames.shape[0] // 2,
                full_frames.shape[1] // 2, full_frames.shape[1] // 2
            )
            full_frames = cv2.copyMakeBorder(
                full_frames, pad_top, pad_bottom, pad_left, pad_right,
                cv2.BORDER_CONSTANT, value=0
            )
            dets = app.get(full_frames)
            if len(dets) == 0:
                raise gr.Error(f"No head detected on the {'target' if target else 'source'} image")
        kps = dets[0]['kps']
        # Wide crop around the whole head; for the target, also keep the
        # affine matrix M so the swapped head can be pasted back later
        wide = wide_crop_face(full_frames, kps, return_M=target)
        if target:
            wide, M = wide
        arc = norm_crop(full_frames, kps)
        mask = calc_mask(wide)
        arc = normalize_and_torch(arc)
        wide = normalize_and_torch(wide)
        if target:
            return wide, arc, mask, full_frames, M
        return wide, arc, mask
    wide_source, arc_source, mask_source = process_img(source)
    wide_target, arc_target, mask_target, full_frame, M = process_img(target, target=True)

    # Add a singleton frame axis to the source tensors for the aligner
    wide_source = wide_source.unsqueeze(1)
    arc_source = arc_source.unsqueeze(1)

    X_dict = {
        'source': {
            'face_arc': arc_source,
            'face_wide': wide_source * mask_source,
            'face_wide_mask': mask_source
        },
        'target': {
            'face_arc': arc_target,
            'face_wide': wide_target * mask_target,
            'face_wide_mask': mask_target
        }
    }
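
    # Stage 1: the aligner reenacts the source head in the target pose,
    # returning the rendered head ('fake_rgbs') and its mask ('fake_segm').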
    with torch.no_grad():
        output = aligner(X_dict)

    target_parsing = infer_parsing(wide_target)
    # Replace the target head region with inpainted background so the
    # blender composites the new head onto a clean plate
    pseudo_norm_target = calc_pseudo_target_bg(wide_target, target_parsing)
    # Soft head mask of the aligner output (BGR -> RGB, [-1, 1] -> [0, 1])
    soft_mask = calc_mask(((output['fake_rgbs'] * output['fake_segm'])[0, [2, 1, 0], :, :] + 1) / 2)[None]
    new_source = output['fake_rgbs'] * soft_mask[:, None, ...] + pseudo_norm_target * (1 - soft_mask[:, None, ...])

    blender_input = {
        'face_source': new_source,
        'gray_source': rgb_to_grayscale(new_source[0][[2, 1, 0], ...]).unsqueeze(0),
        'face_target': wide_target,
        'mask_source': infer_parsing(output['fake_rgbs'] * output['fake_segm']),
        'mask_target': target_parsing,
        'mask_source_noise': None,
        'mask_target_noise': None,
        'alpha_source': soft_mask
    }

    output_b = blender(blender_input, inpainter=inpainter)
    # [-1, 1] CHW BGR tensor -> uint8 HWC RGB image
    np_output = np.uint8((output_b['oup'][0].detach().cpu().numpy().transpose((1, 2, 0))[:, :, ::-1] / 2 + 0.5) * 255)
    # Paste the swapped head back into the full frame via the inverse of M
    result = copy_head_back(np_output, full_frame[..., ::-1], M)
    return Image.fromarray(result)
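
# Model setup: everything below runs once at startup, and infer_headswap
# picks these objects up as module-level globals.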
if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # Generator params
    parser.add_argument('--config_a', default='./configs/aligner.yaml', type=str, help='Path to Aligner config')
    parser.add_argument('--config_b', default='./configs/blender.yaml', type=str, help='Path to Blender config')
    parser.add_argument('--source', default='./examples/images/hab.jpg', type=str, help='Path to source image')
    parser.add_argument('--target', default='./examples/images/elon.jpg', type=str, help='Path to target image')
    parser.add_argument('--ckpt_a', default='./aligner_checkpoints/aligner_1020_gaze_final.ckpt', type=str, help='Aligner checkpoint')
    parser.add_argument('--ckpt_b', default='./blender_checkpoints/blender_lama.ckpt', type=str, help='Blender checkpoint')
    parser.add_argument('--save_path', default='result.png', type=str, help='Path to save the result')
    args = parser.parse_args()
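
    # All default config, checkpoint, and weight paths resolve inside the
    # repository snapshot downloaded at import time.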
    # OmegaConf reads the YAML configs directly from their paths
    cfg_a = OmegaConf.load(args.config_a)
    cfg_b = OmegaConf.load(args.config_b)

    aligner = AlignerModule(cfg_a)
    ckpt = torch.load(args.ckpt_a, map_location='cpu')
    aligner.load_state_dict(ckpt, strict=False)
    aligner.eval()
    aligner.cuda()
    blender = BlenderModule(cfg_b)
    blender.load_state_dict(torch.load(args.ckpt_b, map_location='cpu')["state_dict"], strict=False)
    blender.eval()
    blender.cuda()

    inpainter = LamaInpainter('cpu')
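    # NOTE: the LaMa inpainter is constructed on CPU; presumably this trades a
    # slower background fill for GPU memory headroom for the two generators.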
    # insightface face detector (its keypoints drive the crops)
    app = FaceAnalysis(providers=['CUDAExecutionProvider'], allowed_modules=['detection'])
    app.prepare(ctx_id=0, det_size=(640, 640))

    # StyleMatte head matting model
    segment_model = StyleMatte()
    segment_model.load_state_dict(
        torch.load(
            './repos/stylematte/stylematte/checkpoints/stylematte_synth.pth',
            map_location='cpu'
        )
    )
    segment_model = segment_model.cuda()
    segment_model.eval()
    # ONNX face-parsing model (SegFormer-B5) for semantic masks
    providers = [
        ("CUDAExecutionProvider", {})
    ]
    parsings_session = ort.InferenceSession('./weights/segformer_B5_ce.onnx', providers=providers)
    input_name = parsings_session.get_inputs()[0].name
    output_names = [output.name for output in parsings_session.get_outputs()]

    # Normalization statistics for the parsing model
    mean = np.array([0.51315393, 0.48064056, 0.46301059])[None, :, None, None]
    std = np.array([0.21438347, 0.20799829, 0.20304542])[None, :, None, None]

    def infer_parsing(img):
        # Face parsing on a [-1, 1] BGR tensor: flip to RGB, rescale to
        # [0, 1], normalize, run the ONNX session, return a CUDA tensor
        inp = (((img[:, [2, 1, 0], ...] / 2 + 0.5).cpu().detach().numpy() - mean) / std).astype(np.float32)
        out = parsings_session.run(output_names, {input_name: inp})[0]
        return torch.tensor(out, device='cuda', dtype=torch.float32)
    # Example images shipped with the repo, used to prefill the demo inputs
    source_pil = Image.open(args.source)
    target_pil = Image.open(args.target)
    with gr.Blocks() as demo:
        with gr.Column():
            with gr.Row():
                with gr.Column():
                    with gr.Row(equal_height=True):
                        input_source = gr.Image(
                            type="pil",
                            label="Input Source",
                            value=source_pil
                        )
                        input_target = gr.Image(
                            type="pil",
                            label="Input Target",
                            value=target_pil
                        )
                    run_button = gr.Button("Generate")
                with gr.Column():
                    result = gr.Image(type='pil', label='Image Output')

            run_button.click(
                fn=infer_headswap,
                inputs=[input_source, input_target],
                outputs=[result]
            )

    demo.launch()