Spaces:
Runtime error
Runtime error
File size: 4,157 Bytes
906e212 0db1bcd 72b39a2 0db1bcd 72b39a2 0db1bcd 906e212 72b39a2 906e212 72b39a2 34383d4 72b39a2 3b6e0c9 72b39a2 906e212 5f0f47d 906e212 5f0f47d 906e212 f45b35d 906e212 72b39a2 5f0f47d 906e212 72b39a2 906e212 0db1bcd 906e212 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 |
import argparse
import functools
import pathlib
import os
import subprocess
import tarfile
# Hugging Face Spaces only: when the SYSTEM environment variable is "spaces",
# reinstall mmcv-full 1.6.2 via mim and swap any OpenCV wheels for
# opencv-python-headless 4.7.0.72. This runs before `import cv2` further down
# the file, so the pinned OpenCV build is the one that gets imported.
# NOTE(review): presumably needed because the Space image ships incompatible
# builds — confirm. Exit codes of the commands are intentionally not checked.
if os.environ.get("SYSTEM") == "spaces":
    import mim

    mim.uninstall("mmcv-full", confirm_yes=True)
    setup_commands = (
        "mim install mmcv-full==1.6.2",
        "pip uninstall -y opencv-python",
        "pip uninstall -y opencv-python-headless",
        "pip install opencv-python-headless==4.7.0.72",
    )
    for command in setup_commands:
        subprocess.call(command.split())
import cv2
import gradio as gr
import huggingface_hub
import numpy as np
import PIL.Image
import anime_face_detector
def load_sample_image_paths():
    """Return the sorted paths of the sample images shown as demo examples.

    If an ``images`` directory does not exist in the current working directory,
    ``images.tar.gz`` is downloaded from the ``hysts/sample-images-TADNE``
    dataset repo on the Hugging Face Hub and extracted into the current working
    directory (presumably creating ``images/`` — the archive layout is not
    visible here). Later calls reuse the existing directory without downloading.

    Returns:
        A sorted list of ``pathlib.Path`` objects, one per entry of ``images/``.
    """
    image_dir = pathlib.Path("images")
    if not image_dir.exists():
        dataset_repo = "hysts/sample-images-TADNE"
        path = huggingface_hub.hf_hub_download(
            dataset_repo, "images.tar.gz", repo_type="dataset"
        )
        with tarfile.open(path) as f:
            # The archive is fetched from the network. Where the running Python
            # supports extraction filters, use "data" so members with absolute
            # paths, ".." components or links escaping the target directory
            # are rejected instead of being written outside it.
            if hasattr(tarfile, "data_filter"):
                f.extractall(filter="data")
            else:
                f.extractall()
    return sorted(image_dir.glob("*"))
def detect(
    img,
    face_score_threshold: float,
    landmark_score_threshold: float,
    detector: anime_face_detector.LandmarkDetector,
) -> PIL.Image.Image:
    """Run the face/landmark detector on an image file and draw its results.

    Detections whose face score is below ``face_score_threshold`` are skipped.
    Each remaining face gets a green rectangle, and each of its landmarks is
    drawn as a filled circle: yellow when the landmark score is below
    ``landmark_score_threshold``, red otherwise. Line thickness and circle
    radius scale with the size of the face box.

    Args:
        img: Path of the input image (the Gradio input component uses
            ``type="filepath"``), or a falsy value when no image was given.
        face_score_threshold: Minimum face score for a detection to be drawn.
        landmark_score_threshold: Landmarks scoring below this are drawn in
            the "low confidence" colour.
        detector: Callable detector, as built by
            ``anime_face_detector.create_detector`` in ``main``.

    Returns:
        An RGB ``PIL.Image.Image`` with the annotations drawn on a copy of the
        input, or ``None`` when ``img`` is falsy.

    Raises:
        ValueError: If OpenCV cannot read the file at ``img``.
    """
    if not img:
        return None
    image = cv2.imread(img)
    if image is None:
        # cv2.imread reports an unreadable or missing file by returning None
        # instead of raising; stop here with a clear message rather than
        # handing None to the detector.
        raise ValueError(f"Could not read image file: {img}")
    preds = detector(image)
    # Annotations are drawn onto a copy; colours below are in OpenCV's BGR order.
    res = image.copy()
    for pred in preds:
        # The first four "bbox" values are used as the rectangle's two corners,
        # the fifth as the face score.
        bbox = pred["bbox"]
        box, face_score = bbox[:4], bbox[4]
        if face_score < face_score_threshold:
            continue
        box = np.round(box).astype(int)
        # Thickness grows by ~3 px per 256 px of the box's longer side, min 2 px.
        lt = max(2, int(3 * (box[2:] - box[:2]).max() / 256))
        # Plain Python ints, so the coordinate tuples handed to OpenCV do not
        # depend on how cv2 parses NumPy scalar types.
        top_left = tuple(int(v) for v in box[:2])
        bottom_right = tuple(int(v) for v in box[2:])
        cv2.rectangle(res, top_left, bottom_right, (0, 255, 0), lt)
        # Each keypoint row is unpacked as coordinates followed by a score
        # (presumably x, y, score — confirm against anime_face_detector).
        for *pt, landmark_score in pred["keypoints"]:
            if landmark_score < landmark_score_threshold:
                color = (0, 255, 255)  # yellow in BGR: low-confidence landmark
            else:
                color = (0, 0, 255)  # red in BGR
            center = tuple(int(v) for v in np.round(pt))
            cv2.circle(res, center, lt, color, cv2.FILLED)
    res = cv2.cvtColor(res, cv2.COLOR_BGR2RGB)
    return PIL.Image.fromarray(res)
def main():
    """Parse command-line options, build the detector and launch the Gradio demo."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--detector", type=str, default="yolov3", choices=["yolov3", "faster-rcnn"]
    )
    parser.add_argument("--device", type=str, default="cpu", choices=["cuda:0", "cpu"])
    parser.add_argument("--face-score-threshold", type=float, default=0.5)
    parser.add_argument("--landmark-score-threshold", type=float, default=0.3)
    parser.add_argument("--score-slider-step", type=float, default=0.05)
    parser.add_argument("--port", type=int)
    parser.add_argument("--debug", action="store_true")
    parser.add_argument("--share", action="store_true")
    parser.add_argument("--live", action="store_true")
    args = parser.parse_args()

    image_paths = load_sample_image_paths()
    # Example rows use the configured thresholds (not hard-coded values) so
    # they agree with the sliders' initial values set from the same options.
    examples = [
        [path.as_posix(), args.face_score_threshold, args.landmark_score_threshold]
        for path in image_paths
    ]

    detector = anime_face_detector.create_detector(args.detector, device=args.device)
    # Bind the detector so Gradio only supplies the three UI inputs.
    func = functools.partial(detect, detector=detector)

    title = "edisonlee55/hysts-anime-face-detector"
    description = "Demo for edisonlee55/hysts-anime-face-detector. To use it, simply upload your image, or click one of the examples to load them. Read more at the links below."
    article = "<a href='https://github.com/edisonlee55/hysts-anime-face-detector'>GitHub Repo</a>"

    gr.Interface(
        func,
        [
            gr.Image(type="filepath", label="Input"),
            gr.Slider(
                0,
                1,
                step=args.score_slider_step,
                value=args.face_score_threshold,
                label="Face Score Threshold",
            ),
            gr.Slider(
                0,
                1,
                step=args.score_slider_step,
                value=args.landmark_score_threshold,
                label="Landmark Score Threshold",
            ),
        ],
        gr.Image(type="pil", label="Output"),
        title=title,
        description=description,
        article=article,
        examples=examples,
        live=args.live,
    ).launch(debug=args.debug, share=args.share, server_port=args.port)


if __name__ == "__main__":
    main()
|