File size: 5,216 Bytes
6672810
d6754a0
 
2892b55
 
 
 
 
d6754a0
6672810
 
 
b558187
6672810
2c5167f
6672810
2c5167f
6672810
 
 
 
2c5167f
 
 
 
 
 
 
 
 
 
 
 
 
 
6672810
 
 
 
 
 
 
 
2c5167f
 
 
 
 
 
6672810
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2c5167f
 
 
 
 
d55cf7e
6672810
 
 
 
 
 
 
 
 
d55cf7e
6672810
 
 
b558187
eeb0f60
b558187
eeb0f60
 
b558187
 
 
 
cc3443c
eeb0f60
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import os
import subprocess

# subprocess.run(
#     "pip install dlib==19.24.2",
#     capture_output=True,
#     shell=True,
# )

import sys
import face_detection
import PIL
from PIL import Image, ImageOps, ImageFile
import numpy as np
import cv2 as cv
import torch

# This script only runs inference, so gradient tracking is switched off
# globally before the network is loaded.
torch.set_grad_enabled(False)
# Load the TorchScript-serialized network; the path is relative to the
# current working directory.
model = torch.jit.load('u2net_bce_itr_16000_train_3.835149_tar_0.542587-400x_360x.jit.pt')
# Put the loaded module into evaluation mode.
model.eval()

# https://en.wikipedia.org/wiki/Unsharp_masking
# https://stackoverflow.com/a/55590133/1495606
def unsharp_mask(image, kernel_size=(5, 5), sigma=1.0, amount=2.0, threshold=0):
    """Return a sharpened uint8 version of the image, using an unsharp mask.

    The image is blurred with a Gaussian kernel and the result is computed as
    ``(amount + 1) * image - amount * blurred``, clipped to 0..255, rounded and
    converted to ``np.uint8``.

    Args:
        image: NumPy array accepted by ``cv.GaussianBlur``; values are
            expected to lie in the 0..255 range.
        kernel_size: Gaussian kernel size passed to ``cv.GaussianBlur``.
        sigma: Gaussian standard deviation passed to ``cv.GaussianBlur``.
        amount: Strength of the sharpening.
        threshold: When greater than 0, pixels whose absolute difference
            from the blurred image is below this value keep their original
            (unsharpened) value.

    Returns:
        The sharpened image as an ``np.uint8`` array.
    """
    blurred = cv.GaussianBlur(image, kernel_size, sigma)
    sharpened = float(amount + 1) * image - float(amount) * blurred
    sharpened = np.clip(sharpened, 0, 255).round().astype(np.uint8)
    if threshold > 0:
        # Compare in floating point: subtracting two uint8 arrays wraps
        # around (e.g. 3 - 5 == 254) and would hide low-contrast pixels.
        difference = (
            np.asarray(image, dtype=np.float64)
            - np.asarray(blurred, dtype=np.float64)
        )
        low_contrast_mask = np.absolute(difference) < threshold
        # Convert explicitly: np.copyto refuses a float -> uint8 copy under
        # its default 'same_kind' casting rule.
        untouched = np.clip(image, 0, 255).round().astype(np.uint8)
        np.copyto(sharpened, untouched, where=low_contrast_mask)
    return sharpened

def normPRED(d):
    """Min-max normalize ``d`` so its values span the range 0..1.

    Args:
        d: Non-empty NumPy array (``np.max`` raises ``ValueError`` on an
            empty array).

    Returns:
        Array of the same shape computed as ``(d - min) / (max - min)``.
        When every element of ``d`` is equal the result is all zeros
        instead of the NaNs a 0/0 division would produce.
    """
    ma = np.max(d)
    mi = np.min(d)

    span = ma - mi
    if span == 0:
        # Constant input: ``d - mi`` is all zeros, so dividing by 1 yields
        # zeros of the usual result dtype without a divide-by-zero.
        span = 1

    return (d - mi) / span

def array_to_np(array_in):
    """Rescale an array to 0..255 and move its leading axis to the end.

    The values are min-max normalized with ``normPRED``, multiplied by 255,
    singleton dimensions are dropped, and the axes are reordered with
    ``(1, 2, 0)``. The squeezed array must therefore be 3-D
    (presumably channels, height, width -- confirm against the model output).

    Returns:
        Floating-point NumPy array with the first squeezed axis last.
    """
    scaled = 255.0 * normPRED(array_in)
    squeezed = np.squeeze(scaled)
    return np.transpose(squeezed, (1, 2, 0))

def array_to_image(array_in):
    """Convert an array to a PIL image after rescaling it to 0..255.

    The values are min-max normalized with ``normPRED``, multiplied by 255,
    singleton dimensions are dropped, the axes are reordered with
    ``(1, 2, 0)`` and the result is cast to ``np.uint8`` before being handed
    to ``Image.fromarray``.
    """
    scaled = np.squeeze(normPRED(array_in) * 255.0)
    pixels = scaled.transpose(1, 2, 0).astype(np.uint8)
    return Image.fromarray(pixels)


def image_as_array(image_in):
    """Convert an image to a normalized ``(1, 3, H, W)`` float array.

    The pixels are divided by the image's maximum value and then each of the
    three output channels is standardized with a fixed mean / standard
    deviation pair. Images with fewer than three channels (including plain
    2-D grayscale arrays) have their first channel replicated into all three
    outputs using the first mean/std pair; for images with three or more
    channels only the first three are used.

    Args:
        image_in: Anything ``np.array`` can convert to an ``(H, W)`` or
            ``(H, W, C)`` array, e.g. a PIL image.

    Returns:
        ``float64`` NumPy array of shape ``(1, 3, H, W)``.
    """
    pixels = np.array(image_in, np.float32)
    if pixels.ndim == 2:
        # Grayscale input has no channel axis; add one so indexing below works.
        pixels = pixels[:, :, np.newaxis]

    peak = np.max(pixels)
    if peak > 0:
        # Skip the division for an all-black image to avoid 0/0 -> NaN.
        pixels = pixels / peak

    normalized = np.zeros((pixels.shape[0], pixels.shape[1], 3))
    if pixels.shape[2] < 3:
        gray = (pixels[:, :, 0] - 0.485) / 0.229
        normalized[:, :, 0] = gray
        normalized[:, :, 1] = gray
        normalized[:, :, 2] = gray
    else:
        normalized[:, :, 0] = (pixels[:, :, 0] - 0.485) / 0.229
        normalized[:, :, 1] = (pixels[:, :, 1] - 0.456) / 0.224
        normalized[:, :, 2] = (pixels[:, :, 2] - 0.406) / 0.225

    # Channel-last -> channel-first, then prepend a batch axis of length 1.
    normalized = normalized.transpose((2, 0, 1))
    return np.expand_dims(normalized, 0)

def find_aligned_face(image_in, size=400):
    """Run ``face_detection.align`` on the first face of ``image_in``.

    Args:
        image_in: Image handed to ``face_detection.align``.
        size: Forwarded as ``output_size``.

    Returns:
        The ``(aligned_image, n_faces, quad)`` triple produced by
        ``face_detection.align`` for ``face_index=0``.
    """
    face, face_count, face_quad = face_detection.align(
        image_in,
        face_index=0,
        output_size=size,
    )
    return face, face_count, face_quad

def align_first_face(image_in, size=400):
    """Return the network input array for the first face in ``image_in``.

    When a face is found, the aligned crop from ``find_aligned_face`` is
    converted with ``image_as_array``. When no face is found, the whole image
    is EXIF-transposed (best effort), resized to ``size`` x ``size`` and
    converted instead, so a result is produced either way.

    Args:
        image_in: PIL image to process.
        size: Edge length requested for the aligned or resized image.

    Returns:
        The array produced by ``image_as_array``.
    """
    aligned_image, n_faces, _quad = find_aligned_face(image_in, size=size)
    if n_faces != 0:
        return image_as_array(aligned_image)

    try:
        image_in = ImageOps.exif_transpose(image_in)
    except Exception:
        # Best effort only: a broken EXIF block must not abort the request,
        # but KeyboardInterrupt/SystemExit are no longer swallowed.
        print("exif problem, not rotating")
    image_in = image_in.resize((size, size))
    return image_as_array(image_in)

def img_concat_h(im1, im2):
    """Place ``im1`` and ``im2`` side by side on a new RGB image.

    The canvas is as wide as both images together and as tall as ``im1``;
    ``im1`` is pasted at the left edge and ``im2`` immediately to its right.
    """
    left_width = im1.width
    canvas = Image.new('RGB', (left_width + im2.width, im1.height))
    for tile, x_offset in ((im1, 0), (im2, left_width)):
        canvas.paste(tile, (x_offset, 0))
    return canvas

import gradio as gr

def face2hero(
    img: Image.Image,
    size: int
) -> Image.Image:
    """Stylize the first face in ``img`` and return input and result side by side.

    The image is aligned with ``align_first_face``, run through the global
    ``model``, and the second element of the model's output is rescaled,
    sharpened with ``unsharp_mask`` and pasted to the right of the
    (de-normalized) network input.

    Args:
        img: Input PIL image; ``None`` (e.g. nothing uploaded) is tolerated.
        size: Edge length used when aligning / resizing the input. It is now
            actually forwarded to ``align_first_face``; previously it was
            ignored and the default of 400 was always used.

    Returns:
        The combined PIL image, or ``None`` when there is no input image.
    """
    if img is None:
        return None

    aligned_img = align_first_face(img, size=size)
    if aligned_img is None:
        return None

    # Named ``model_input`` to avoid shadowing the ``input`` builtin.
    model_input = torch.Tensor(aligned_img)
    results = model(model_input)
    hero_np_image = array_to_np(results[1].detach().numpy())
    hero_image = Image.fromarray(unsharp_mask(hero_np_image))

    return img_concat_h(array_to_image(aligned_img), hero_image)

def inference(img):
    """Callback handed to the Gradio interface: run ``face2hero`` at size 400."""
    return face2hero(img, size=400)
      
  
# Heading passed to the Gradio interface as ``title``.
title = "Comics hero U2Net"
# Short usage text passed to the Gradio interface as ``description``.
description = "Style transfer a face into one of a \"Comics hero\". Upload an image with a face, or click on one of the examples below. If a face could not be detected, an image will still be created."
# HTML snippet passed to the Gradio interface as ``article``: repository link,
# sample images and model credit.
article = "<hr><p style='text-align: center'>See the <a href='https://github.com/Norod/U-2-Net-StyleTransfer' target='_blank'>Github Repo</a></p><p style='text-align: center'>samples: <img src='https://hf.space/gradioiframe/Norod78/ComicsHeroU2Net/file/Sample00001.jpg' alt='Sample00001'/><img src='https://hf.space/gradioiframe/Norod78/ComicsHeroU2Net/file/Sample00002.jpg' alt='Sample00002'/><img src='https://hf.space/gradioiframe/Norod78/ComicsHeroU2Net/file/Sample00003.jpg' alt='Sample00003'/><img src='https://hf.space/gradioiframe/Norod78/ComicsHeroU2Net/file/Sample00004.jpg' alt='Sample00004'/><img src='https://hf.space/gradioiframe/Norod78/ComicsHeroU2Net/file/Sample00005.jpg' alt='Sample00005'/></p><p>The \"Comics Hero (U2Net)\" model was trained by <a href='https://linktr.ee/Norod78' target='_blank'>Doron Adler</a></p>"

# Example inputs passed to the Gradio interface as ``examples``: one
# single-element list (an image file name) per example.
examples=[['Example00001.jpg'],['Example00002.jpg'],['Example00003.jpg'],['Example00004.jpg'],['Example00005.jpg'], ['Example00006.jpg']]

# Build the web UI: one PIL image in, one PIL image out, served by
# ``inference``. Flagging is disabled.
demo = gr.Interface(
    inference, 
    inputs=[gr.Image(type="pil", label="Input")],
    outputs=[gr.Image(type="pil", label="Output")],
    title=title,
    description=description,
    article=article,
    examples=examples,
    allow_flagging="never"
    )

# Enable the request queue, then start the app. Both calls run at import
# time because there is no ``__main__`` guard.
demo.queue()
demo.launch()