Spaces: Runtime error
Init
- .gitignore +162 -0
- anime_face_detector/__init__.py +54 -0
- anime_face_detector/configs/mmdet/faster-rcnn.py +66 -0
- anime_face_detector/configs/mmdet/yolov3.py +47 -0
- anime_face_detector/configs/mmpose/hrnetv2.py +250 -0
- anime_face_detector/detector.py +147 -0
- app.py +120 -0
- requirements.txt +16 -0
.gitignore
ADDED
@@ -0,0 +1,162 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
#   For a library or package, you might want to ignore these files since the code is
#   intended to run in multiple environments; otherwise, check them in:
#   .python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# poetry
#   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
#   This is especially recommended for binary packages to ensure reproducibility, and is more
#   commonly ignored for libraries.
#   https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
#   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
#   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
#   in version control.
#   https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
#  JetBrains specific template is maintained in a separate JetBrains.gitignore that can
#  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
#  and can be added to the global gitignore or merged into this file.  For a more nuclear
#  option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

.vscode/
anime_face_detector/__init__.py
ADDED
@@ -0,0 +1,54 @@
import pathlib

import torch

from .detector import LandmarkDetector


def get_config_path(model_name: str) -> pathlib.Path:
    assert model_name in ['faster-rcnn', 'yolov3', 'hrnetv2']

    package_path = pathlib.Path(__file__).parent.resolve()
    if model_name in ['faster-rcnn', 'yolov3']:
        config_dir = package_path / 'configs' / 'mmdet'
    else:
        config_dir = package_path / 'configs' / 'mmpose'
    return config_dir / f'{model_name}.py'


def get_checkpoint_path(model_name: str) -> pathlib.Path:
    assert model_name in ['faster-rcnn', 'yolov3', 'hrnetv2']
    if model_name in ['faster-rcnn', 'yolov3']:
        file_name = f'mmdet_anime-face_{model_name}.pth'
    else:
        file_name = f'mmpose_anime-face_{model_name}.pth'

    model_dir = pathlib.Path(torch.hub.get_dir()) / 'checkpoints'
    model_dir.mkdir(exist_ok=True, parents=True)
    model_path = model_dir / file_name
    if not model_path.exists():
        url = f'https://github.com/hysts/anime-face-detector/releases/download/v0.0.1/{file_name}'
        torch.hub.download_url_to_file(url, model_path.as_posix())

    return model_path


def create_detector(face_detector_name: str = 'yolov3',
                    landmark_model_name: str = 'hrnetv2',
                    device: str = 'cuda:0',
                    flip_test: bool = True,
                    box_scale_factor: float = 1.1) -> LandmarkDetector:
    assert face_detector_name in ['yolov3', 'faster-rcnn']
    assert landmark_model_name in ['hrnetv2']
    detector_config_path = get_config_path(face_detector_name)
    landmark_config_path = get_config_path(landmark_model_name)
    detector_checkpoint_path = get_checkpoint_path(face_detector_name)
    landmark_checkpoint_path = get_checkpoint_path(landmark_model_name)
    model = LandmarkDetector(landmark_config_path,
                             landmark_checkpoint_path,
                             detector_config_path,
                             detector_checkpoint_path,
                             device=device,
                             flip_test=flip_test,
                             box_scale_factor=box_scale_factor)
    return model
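
A minimal usage sketch of this module (not part of the commit): create_detector downloads the checkpoints on first use via torch.hub, and the returned LandmarkDetector is called directly on a BGR image. The image path below is hypothetical.

import cv2
from anime_face_detector import create_detector

detector = create_detector('yolov3', device='cpu')  # or 'cuda:0' if a GPU is available
image = cv2.imread('assets/input.jpg')  # hypothetical path; any BGR image works
preds = detector(image)
# Each prediction is a dict with a 5-element 'bbox' ([x0, y0, x1, y1, score])
# and a (28, 3) 'keypoints' array ([x, y, score] per landmark).
print(preds[0]['bbox'], preds[0]['keypoints'].shape)  # assumes at least one face was found
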
anime_face_detector/configs/mmdet/faster-rcnn.py
ADDED
@@ -0,0 +1,66 @@
model = dict(type='FasterRCNN',
             backbone=dict(type='ResNet',
                           depth=50,
                           num_stages=4,
                           out_indices=(0, 1, 2, 3),
                           frozen_stages=1,
                           norm_cfg=dict(type='BN', requires_grad=True),
                           norm_eval=True,
                           style='pytorch'),
             neck=dict(type='FPN',
                       in_channels=[256, 512, 1024, 2048],
                       out_channels=256,
                       num_outs=5),
             rpn_head=dict(type='RPNHead',
                           in_channels=256,
                           feat_channels=256,
                           anchor_generator=dict(type='AnchorGenerator',
                                                 scales=[8],
                                                 ratios=[0.5, 1.0, 2.0],
                                                 strides=[4, 8, 16, 32, 64]),
                           bbox_coder=dict(type='DeltaXYWHBBoxCoder',
                                           target_means=[0.0, 0.0, 0.0, 0.0],
                                           target_stds=[1.0, 1.0, 1.0, 1.0])),
             roi_head=dict(
                 type='StandardRoIHead',
                 bbox_roi_extractor=dict(type='SingleRoIExtractor',
                                         roi_layer=dict(type='RoIAlign',
                                                        output_size=7,
                                                        sampling_ratio=0),
                                         out_channels=256,
                                         featmap_strides=[4, 8, 16, 32]),
                 bbox_head=dict(type='Shared2FCBBoxHead',
                                in_channels=256,
                                fc_out_channels=1024,
                                roi_feat_size=7,
                                num_classes=1,
                                bbox_coder=dict(
                                    type='DeltaXYWHBBoxCoder',
                                    target_means=[0.0, 0.0, 0.0, 0.0],
                                    target_stds=[0.1, 0.1, 0.2, 0.2]),
                                reg_class_agnostic=False)),
             test_cfg=dict(rpn=dict(nms_pre=1000,
                                    max_per_img=1000,
                                    nms=dict(type='nms', iou_threshold=0.7),
                                    min_bbox_size=0),
                           rcnn=dict(score_thr=0.05,
                                     nms=dict(type='nms', iou_threshold=0.5),
                                     max_per_img=100)))
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='MultiScaleFlipAug',
         img_scale=(1333, 800),
         flip=False,
         transforms=[
             dict(type='Resize', keep_ratio=True),
             dict(type='RandomFlip'),
             dict(type='Normalize',
                  mean=[123.675, 116.28, 103.53],
                  std=[58.395, 57.12, 57.375],
                  to_rgb=True),
             dict(type='Pad', size_divisor=32),
             dict(type='DefaultFormatBundle'),
             dict(type='Collect', keys=['img'])
         ])
]
data = dict(test=dict(pipeline=test_pipeline))
anime_face_detector/configs/mmdet/yolov3.py
ADDED
@@ -0,0 +1,47 @@
model = dict(type='YOLOV3',
             backbone=dict(type='Darknet', depth=53, out_indices=(3, 4, 5)),
             neck=dict(type='YOLOV3Neck',
                       num_scales=3,
                       in_channels=[1024, 512, 256],
                       out_channels=[512, 256, 128]),
             bbox_head=dict(type='YOLOV3Head',
                            num_classes=1,
                            in_channels=[512, 256, 128],
                            out_channels=[1024, 512, 256],
                            anchor_generator=dict(type='YOLOAnchorGenerator',
                                                  base_sizes=[[(116, 90),
                                                               (156, 198),
                                                               (373, 326)],
                                                              [(30, 61),
                                                               (62, 45),
                                                               (59, 119)],
                                                              [(10, 13),
                                                               (16, 30),
                                                               (33, 23)]],
                                                  strides=[32, 16, 8]),
                            bbox_coder=dict(type='YOLOBBoxCoder'),
                            featmap_strides=[32, 16, 8]),
             test_cfg=dict(nms_pre=1000,
                           min_bbox_size=0,
                           score_thr=0.05,
                           conf_thr=0.005,
                           nms=dict(type='nms', iou_threshold=0.45),
                           max_per_img=100))
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='MultiScaleFlipAug',
         img_scale=(608, 608),
         flip=False,
         transforms=[
             dict(type='Resize', keep_ratio=True),
             dict(type='RandomFlip'),
             dict(type='Normalize',
                  mean=[0, 0, 0],
                  std=[255.0, 255.0, 255.0],
                  to_rgb=True),
             dict(type='Pad', size_divisor=32),
             dict(type='DefaultFormatBundle'),
             dict(type='Collect', keys=['img'])
         ])
]
data = dict(test=dict(pipeline=test_pipeline))
anime_face_detector/configs/mmpose/hrnetv2.py
ADDED
@@ -0,0 +1,250 @@
channel_cfg = dict(num_output_channels=28,
                   dataset_joints=28,
                   dataset_channel=[
                       list(range(28)),
                   ],
                   inference_channel=list(range(28)))

model = dict(
    type='TopDown',
    backbone=dict(type='HRNet',
                  in_channels=3,
                  extra=dict(stage1=dict(num_modules=1,
                                         num_branches=1,
                                         block='BOTTLENECK',
                                         num_blocks=(4, ),
                                         num_channels=(64, )),
                             stage2=dict(num_modules=1,
                                         num_branches=2,
                                         block='BASIC',
                                         num_blocks=(4, 4),
                                         num_channels=(18, 36)),
                             stage3=dict(num_modules=4,
                                         num_branches=3,
                                         block='BASIC',
                                         num_blocks=(4, 4, 4),
                                         num_channels=(18, 36, 72)),
                             stage4=dict(num_modules=3,
                                         num_branches=4,
                                         block='BASIC',
                                         num_blocks=(4, 4, 4, 4),
                                         num_channels=(18, 36, 72, 144),
                                         multiscale_output=True),
                             upsample=dict(mode='bilinear',
                                           align_corners=False))),
    keypoint_head=dict(type='TopdownHeatmapSimpleHead',
                       in_channels=[18, 36, 72, 144],
                       in_index=(0, 1, 2, 3),
                       input_transform='resize_concat',
                       out_channels=channel_cfg['num_output_channels'],
                       num_deconv_layers=0,
                       extra=dict(final_conv_kernel=1,
                                  num_conv_layers=1,
                                  num_conv_kernels=(1, )),
                       loss_keypoint=dict(type='JointsMSELoss',
                                          use_target_weight=True)),
    test_cfg=dict(flip_test=True,
                  post_process='unbiased',
                  shift_heatmap=True,
                  modulate_kernel=11))

data_cfg = dict(image_size=[256, 256],
                heatmap_size=[64, 64],
                num_output_channels=channel_cfg['num_output_channels'],
                num_joints=channel_cfg['dataset_joints'],
                dataset_channel=channel_cfg['dataset_channel'],
                inference_channel=channel_cfg['inference_channel'])

test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='TopDownAffine'),
    dict(type='ToTensor'),
    dict(type='NormalizeTensor',
         mean=[0.485, 0.456, 0.406],
         std=[0.229, 0.224, 0.225]),
    dict(type='Collect',
         keys=['img'],
         meta_keys=['image_file', 'center', 'scale', 'rotation',
                    'flip_pairs']),
]

dataset_info = dict(dataset_name='anime_face',
                    paper_info=dict(),
                    keypoint_info={
                        0:
                        dict(name='kpt-0',
                             id=0,
                             color=[255, 255, 255],
                             type='',
                             swap='kpt-4'),
                        1:
                        dict(name='kpt-1',
                             id=1,
                             color=[255, 255, 255],
                             type='',
                             swap='kpt-3'),
                        2:
                        dict(name='kpt-2',
                             id=2,
                             color=[255, 255, 255],
                             type='',
                             swap=''),
                        3:
                        dict(name='kpt-3',
                             id=3,
                             color=[255, 255, 255],
                             type='',
                             swap='kpt-1'),
                        4:
                        dict(name='kpt-4',
                             id=4,
                             color=[255, 255, 255],
                             type='',
                             swap='kpt-0'),
                        5:
                        dict(name='kpt-5',
                             id=5,
                             color=[255, 255, 255],
                             type='',
                             swap='kpt-10'),
                        6:
                        dict(name='kpt-6',
                             id=6,
                             color=[255, 255, 255],
                             type='',
                             swap='kpt-9'),
                        7:
                        dict(name='kpt-7',
                             id=7,
                             color=[255, 255, 255],
                             type='',
                             swap='kpt-8'),
                        8:
                        dict(name='kpt-8',
                             id=8,
                             color=[255, 255, 255],
                             type='',
                             swap='kpt-7'),
                        9:
                        dict(name='kpt-9',
                             id=9,
                             color=[255, 255, 255],
                             type='',
                             swap='kpt-6'),
                        10:
                        dict(name='kpt-10',
                             id=10,
                             color=[255, 255, 255],
                             type='',
                             swap='kpt-5'),
                        11:
                        dict(name='kpt-11',
                             id=11,
                             color=[255, 255, 255],
                             type='',
                             swap='kpt-19'),
                        12:
                        dict(name='kpt-12',
                             id=12,
                             color=[255, 255, 255],
                             type='',
                             swap='kpt-18'),
                        13:
                        dict(name='kpt-13',
                             id=13,
                             color=[255, 255, 255],
                             type='',
                             swap='kpt-17'),
                        14:
                        dict(name='kpt-14',
                             id=14,
                             color=[255, 255, 255],
                             type='',
                             swap='kpt-22'),
                        15:
                        dict(name='kpt-15',
                             id=15,
                             color=[255, 255, 255],
                             type='',
                             swap='kpt-21'),
                        16:
                        dict(name='kpt-16',
                             id=16,
                             color=[255, 255, 255],
                             type='',
                             swap='kpt-20'),
                        17:
                        dict(name='kpt-17',
                             id=17,
                             color=[255, 255, 255],
                             type='',
                             swap='kpt-13'),
                        18:
                        dict(name='kpt-18',
                             id=18,
                             color=[255, 255, 255],
                             type='',
                             swap='kpt-12'),
                        19:
                        dict(name='kpt-19',
                             id=19,
                             color=[255, 255, 255],
                             type='',
                             swap='kpt-11'),
                        20:
                        dict(name='kpt-20',
                             id=20,
                             color=[255, 255, 255],
                             type='',
                             swap='kpt-16'),
                        21:
                        dict(name='kpt-21',
                             id=21,
                             color=[255, 255, 255],
                             type='',
                             swap='kpt-15'),
                        22:
                        dict(name='kpt-22',
                             id=22,
                             color=[255, 255, 255],
                             type='',
                             swap='kpt-14'),
                        23:
                        dict(name='kpt-23',
                             id=23,
                             color=[255, 255, 255],
                             type='',
                             swap=''),
                        24:
                        dict(name='kpt-24',
                             id=24,
                             color=[255, 255, 255],
                             type='',
                             swap='kpt-26'),
                        25:
                        dict(name='kpt-25',
                             id=25,
                             color=[255, 255, 255],
                             type='',
                             swap=''),
                        26:
                        dict(name='kpt-26',
                             id=26,
                             color=[255, 255, 255],
                             type='',
                             swap='kpt-24'),
                        27:
                        dict(name='kpt-27',
                             id=27,
                             color=[255, 255, 255],
                             type='',
                             swap='')
                    },
                    skeleton_info={},
                    joint_weights=[1.] * 28,
                    sigmas=[])

data = dict(test=dict(type='',
                      data_cfg=data_cfg,
                      pipeline=test_pipeline,
                      dataset_info=dataset_info), )
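
The long keypoint_info block above mostly encodes names, marker colors, and horizontal-flip partners for the 28 landmarks; the swap entries are what mmpose's DatasetInfo turns into flip pairs for flip_test. A rough, self-contained sketch of that mapping (abbreviated to three keypoints, not the actual mmpose implementation):

keypoint_info = {
    0: dict(name='kpt-0', swap='kpt-4'),
    2: dict(name='kpt-2', swap=''),
    4: dict(name='kpt-4', swap='kpt-0'),
}
name_to_id = {v['name']: k for k, v in keypoint_info.items()}
flip_pairs = sorted({
    tuple(sorted((k, name_to_id[v['swap']])))
    for k, v in keypoint_info.items() if v['swap']
})
print(flip_pairs)  # [(0, 4)] -- kpt-2 has no partner and stays in place when flipping
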
anime_face_detector/detector.py
ADDED
@@ -0,0 +1,147 @@
from __future__ import annotations

import pathlib
import warnings
from typing import Optional, Union

import cv2
import mmcv
import numpy as np
import torch.nn as nn
from mmdet.apis import inference_detector, init_detector
from mmpose.apis import inference_top_down_pose_model, init_pose_model
from mmpose.datasets import DatasetInfo


class LandmarkDetector:
    def __init__(
            self,
            landmark_detector_config_or_path: Union[mmcv.Config, str,
                                                    pathlib.Path],
            landmark_detector_checkpoint_path: Union[str, pathlib.Path],
            face_detector_config_or_path: Optional[Union[mmcv.Config, str,
                                                          pathlib.Path]] = None,
            face_detector_checkpoint_path: Optional[Union[
                str, pathlib.Path]] = None,
            device: str = 'cuda:0',
            flip_test: bool = True,
            box_scale_factor: float = 1.1):
        landmark_config = self._load_config(landmark_detector_config_or_path)
        self.dataset_info = DatasetInfo(
            landmark_config.dataset_info)  # type: ignore
        face_detector_config = self._load_config(face_detector_config_or_path)

        self.landmark_detector = self._init_pose_model(
            landmark_config, landmark_detector_checkpoint_path, device,
            flip_test)
        self.face_detector = self._init_face_detector(
            face_detector_config, face_detector_checkpoint_path, device)

        self.box_scale_factor = box_scale_factor

    @staticmethod
    def _load_config(
        config_or_path: Optional[Union[mmcv.Config, str, pathlib.Path]]
    ) -> Optional[mmcv.Config]:
        if config_or_path is None or isinstance(config_or_path, mmcv.Config):
            return config_or_path
        return mmcv.Config.fromfile(config_or_path)

    @staticmethod
    def _init_pose_model(config: mmcv.Config,
                         checkpoint_path: Union[str, pathlib.Path],
                         device: str, flip_test: bool) -> nn.Module:
        if isinstance(checkpoint_path, pathlib.Path):
            checkpoint_path = checkpoint_path.as_posix()
        model = init_pose_model(config, checkpoint_path, device=device)
        model.cfg.model.test_cfg.flip_test = flip_test
        return model

    @staticmethod
    def _init_face_detector(config: Optional[mmcv.Config],
                            checkpoint_path: Optional[Union[str,
                                                             pathlib.Path]],
                            device: str) -> Optional[nn.Module]:
        if config is not None:
            if isinstance(checkpoint_path, pathlib.Path):
                checkpoint_path = checkpoint_path.as_posix()
            model = init_detector(config, checkpoint_path, device=device)
        else:
            model = None
        return model

    def _detect_faces(self, image: np.ndarray) -> list[np.ndarray]:
        # Boxes predicted by the mmdet model have the format
        # [x0, y0, x1, y1, score].
        boxes = inference_detector(self.face_detector, image)[0]
        # Scale the boxes by `self.box_scale_factor`.
        boxes = self._update_pred_box(boxes)
        return boxes

    def _update_pred_box(self, pred_boxes: np.ndarray) -> list[np.ndarray]:
        boxes = []
        for pred_box in pred_boxes:
            box = pred_box[:4]
            size = box[2:] - box[:2] + 1
            new_size = size * self.box_scale_factor
            center = (box[:2] + box[2:]) / 2
            tl = center - new_size / 2
            br = tl + new_size
            pred_box[:4] = np.concatenate([tl, br])
            boxes.append(pred_box)
        return boxes

    def _detect_landmarks(
            self, image: np.ndarray,
            boxes: list[dict[str, np.ndarray]]) -> list[dict[str, np.ndarray]]:
        preds, _ = inference_top_down_pose_model(
            self.landmark_detector,
            image,
            boxes,
            format='xyxy',
            dataset_info=self.dataset_info,
            return_heatmap=False)
        return preds

    @staticmethod
    def _load_image(
            image_or_path: Union[np.ndarray, str, pathlib.Path]) -> np.ndarray:
        if isinstance(image_or_path, np.ndarray):
            image = image_or_path
        elif isinstance(image_or_path, str):
            image = cv2.imread(image_or_path)
        elif isinstance(image_or_path, pathlib.Path):
            image = cv2.imread(image_or_path.as_posix())
        else:
            raise ValueError(f'Unsupported input type: {type(image_or_path)}')
        return image

    def __call__(
        self,
        image_or_path: Union[np.ndarray, str, pathlib.Path],
        boxes: Optional[list[np.ndarray]] = None
    ) -> list[dict[str, np.ndarray]]:
        """Detect face landmarks.

        Args:
            image_or_path: An image with BGR channel order or an image path.
            boxes: A list of bounding boxes for faces. Each bounding box
                should be of the form [x0, y0, x1, y1, [score]].

        Returns: A list of detection results. Each detection result has
            a bounding box of the form [x0, y0, x1, y1, [score]] and
            landmarks of the form [x, y, score].
        """
        image = self._load_image(image_or_path)
        if boxes is None:
            if self.face_detector is not None:
                boxes = self._detect_faces(image)
            else:
                warnings.warn(
                    'Neither the face detector nor the bounding box is '
                    'specified, so the entire image is treated as the face '
                    'region.')
                h, w = image.shape[:2]
                boxes = [np.array([0, 0, w - 1, h - 1, 1])]
        box_list = [{'bbox': box} for box in boxes]
        return self._detect_landmarks(image, box_list)
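
The box enlargement in _update_pred_box grows each detected box around its center by box_scale_factor (1.1 by default) before it is handed to the pose model. A worked numeric sketch of that arithmetic, with made-up values:

import numpy as np

box_scale_factor = 1.1
pred_box = np.array([100., 120., 199., 219., 0.98])  # [x0, y0, x1, y1, score], made-up

box = pred_box[:4]
size = box[2:] - box[:2] + 1        # [100., 100.]
new_size = size * box_scale_factor  # [110., 110.]
center = (box[:2] + box[2:]) / 2    # [149.5, 169.5]
tl = center - new_size / 2          # [ 94.5, 114.5]
br = tl + new_size                  # [204.5, 224.5]
pred_box[:4] = np.concatenate([tl, br])
print(pred_box)  # the score in pred_box[4] is left unchanged
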
app.py
ADDED
@@ -0,0 +1,120 @@
import argparse
import functools
import pathlib

import cv2
import gradio as gr
import numpy as np
import PIL.Image
import torch

import anime_face_detector


def detect(
    img,
    face_score_threshold: float,
    landmark_score_threshold: float,
    detector: anime_face_detector.LandmarkDetector,
) -> PIL.Image.Image:
    if not img:
        return None

    image = cv2.imread(img)
    preds = detector(image)

    res = image.copy()
    for pred in preds:
        box = pred["bbox"]
        box, score = box[:4], box[4]
        if score < face_score_threshold:
            continue
        box = np.round(box).astype(int)

        lt = max(2, int(3 * (box[2:] - box[:2]).max() / 256))

        cv2.rectangle(res, tuple(box[:2]), tuple(box[2:]), (0, 255, 0), lt)

        pred_pts = pred["keypoints"]
        for *pt, score in pred_pts:
            if score < landmark_score_threshold:
                color = (0, 255, 255)
            else:
                color = (0, 0, 255)
            pt = np.round(pt).astype(int)
            cv2.circle(res, tuple(pt), lt, color, cv2.FILLED)
    res = cv2.cvtColor(res, cv2.COLOR_BGR2RGB)

    image_pil = PIL.Image.fromarray(res)
    return image_pil


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--detector", type=str, default="yolov3", choices=["yolov3", "faster-rcnn"]
    )
    parser.add_argument(
        "--device", type=str, default="cuda:0", choices=["cuda:0", "cpu"]
    )
    parser.add_argument("--face-score-threshold", type=float, default=0.5)
    parser.add_argument("--landmark-score-threshold", type=float, default=0.3)
    parser.add_argument("--score-slider-step", type=float, default=0.05)
    parser.add_argument("--port", type=int)
    parser.add_argument("--debug", action="store_true")
    parser.add_argument("--share", action="store_true")
    parser.add_argument("--live", action="store_true")
    args = parser.parse_args()

    sample_path = pathlib.Path("assets/input.jpg")
    if not sample_path.exists():
        torch.hub.download_url_to_file(
            "https://raw.githubusercontent.com/edisonlee55/hysts-anime-face-detector/main/assets/input.jpg",
            sample_path.as_posix(),
        )

    detector = anime_face_detector.create_detector(args.detector, device=args.device)
    func = functools.partial(detect, detector=detector)

    title = "edisonlee55/hysts-anime-face-detector"
    description = "Demo for edisonlee55/hysts-anime-face-detector. To use it, simply upload your image, or click one of the examples to load them. Read more at the links below."
    article = "<a href='https://github.com/edisonlee55/hysts-anime-face-detector'>GitHub Repo</a>"

    gr.Interface(
        func,
        [
            gr.Image(type="filepath", label="Input"),
            gr.Slider(
                0,
                1,
                step=args.score_slider_step,
                value=args.face_score_threshold,
                label="Face Score Threshold",
            ),
            gr.Slider(
                0,
                1,
                step=args.score_slider_step,
                value=args.landmark_score_threshold,
                label="Landmark Score Threshold",
            ),
        ],
        gr.Image(type="pil", label="Output"),
        title=title,
        description=description,
        article=article,
        examples=[
            [
                sample_path.as_posix(),
                args.face_score_threshold,
                args.landmark_score_threshold,
            ],
        ],
        live=args.live,
    ).launch(
        debug=args.debug, share=args.share, server_port=args.port, enable_queue=True
    )


if __name__ == "__main__":
    main()
requirements.txt
ADDED
@@ -0,0 +1,16 @@
# python=3.10
openmim==0.3.7
mmcv-full==1.6.2
mmdet==2.28.2
mmpose==0.29.0

numpy==1.24.3
scipy==1.10.1

opencv-python-headless==4.7.0.72

torch==2.0.1
torchvision==0.15.2

# for gradio
# gradio==3.32.0
|