Spaces:
Running
Running
Commit
·
1d213d9
1
Parent(s):
be33c96
update app.py
Browse files- app.py +11 -4
- demo/{locals.py → locales.py} +0 -0
- demo/processor.py +83 -30
- demo/ui.py +2 -2
- demo/utils.py +0 -3
- hivision/creator/human_matting.py +15 -6
app.py
CHANGED
|
@@ -15,10 +15,17 @@ HUMAN_MATTING_MODELS_EXIST = [
|
|
| 15 |
if file.endswith(".onnx") or file.endswith(".mnn")
|
| 16 |
]
|
| 17 |
# 在HUMAN_MATTING_MODELS中的模型才会被加载到Gradio中显示
|
| 18 |
-
|
| 19 |
model for model in HUMAN_MATTING_MODELS if model in HUMAN_MATTING_MODELS_EXIST
|
| 20 |
]
|
| 21 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
FACE_DETECT_MODELS = ["face++ (联网Online API)", "mtcnn"]
|
| 23 |
FACE_DETECT_MODELS_EXPAND = (
|
| 24 |
["retinaface-resnet50"]
|
|
@@ -29,7 +36,7 @@ FACE_DETECT_MODELS_EXPAND = (
|
|
| 29 |
)
|
| 30 |
else []
|
| 31 |
)
|
| 32 |
-
FACE_DETECT_MODELS
|
| 33 |
|
| 34 |
LANGUAGE = ["zh", "en", "ko", "ja"]
|
| 35 |
|
|
@@ -54,8 +61,8 @@ if __name__ == "__main__":
|
|
| 54 |
demo = create_ui(
|
| 55 |
processor,
|
| 56 |
root_dir,
|
| 57 |
-
|
| 58 |
-
|
| 59 |
LANGUAGE,
|
| 60 |
)
|
| 61 |
demo.launch(
|
|
|
|
| 15 |
if file.endswith(".onnx") or file.endswith(".mnn")
|
| 16 |
]
|
| 17 |
# 在HUMAN_MATTING_MODELS中的模型才会被加载到Gradio中显示
|
| 18 |
+
HUMAN_MATTING_MODELS_CHOICE = [
|
| 19 |
model for model in HUMAN_MATTING_MODELS if model in HUMAN_MATTING_MODELS_EXIST
|
| 20 |
]
|
| 21 |
|
| 22 |
+
if len(HUMAN_MATTING_MODELS_CHOICE) == 0:
|
| 23 |
+
raise ValueError(
|
| 24 |
+
"未找到任何存在的人像分割模型,请检查 hivision/creator/weights 目录下的文件"
|
| 25 |
+
+ "\n"
|
| 26 |
+
+ "No existing portrait segmentation model was found, please check the files in the hivision/creator/weights directory."
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
FACE_DETECT_MODELS = ["face++ (联网Online API)", "mtcnn"]
|
| 30 |
FACE_DETECT_MODELS_EXPAND = (
|
| 31 |
["retinaface-resnet50"]
|
|
|
|
| 36 |
)
|
| 37 |
else []
|
| 38 |
)
|
| 39 |
+
FACE_DETECT_MODELS_CHOICE = FACE_DETECT_MODELS + FACE_DETECT_MODELS_EXPAND
|
| 40 |
|
| 41 |
LANGUAGE = ["zh", "en", "ko", "ja"]
|
| 42 |
|
|
|
|
| 61 |
demo = create_ui(
|
| 62 |
processor,
|
| 63 |
root_dir,
|
| 64 |
+
HUMAN_MATTING_MODELS_CHOICE,
|
| 65 |
+
FACE_DETECT_MODELS_CHOICE,
|
| 66 |
LANGUAGE,
|
| 67 |
)
|
| 68 |
demo.launch(
|
demo/{locals.py → locales.py}
RENAMED
|
File without changes
|
demo/processor.py
CHANGED
|
@@ -16,7 +16,7 @@ from demo.utils import range_check
|
|
| 16 |
import gradio as gr
|
| 17 |
import os
|
| 18 |
import time
|
| 19 |
-
from demo.
|
| 20 |
|
| 21 |
|
| 22 |
class IDPhotoProcessor:
|
|
@@ -261,7 +261,7 @@ class IDPhotoProcessor:
|
|
| 261 |
)
|
| 262 |
|
| 263 |
# 生成排版照片
|
| 264 |
-
|
| 265 |
idphoto_json,
|
| 266 |
result_image_standard,
|
| 267 |
language,
|
|
@@ -289,7 +289,10 @@ class IDPhotoProcessor:
|
|
| 289 |
|
| 290 |
# 调整图片大小
|
| 291 |
output_image_path = self._resize_image_if_needed(
|
| 292 |
-
result_image_standard,
|
|
|
|
|
|
|
|
|
|
| 293 |
)
|
| 294 |
|
| 295 |
return self._create_response(
|
|
@@ -297,7 +300,7 @@ class IDPhotoProcessor:
|
|
| 297 |
result_image_hd,
|
| 298 |
result_image_standard_png,
|
| 299 |
result_image_hd_png,
|
| 300 |
-
|
| 301 |
output_image_path,
|
| 302 |
)
|
| 303 |
|
|
@@ -319,7 +322,7 @@ class IDPhotoProcessor:
|
|
| 319 |
|
| 320 |
return result_image_standard, result_image_hd
|
| 321 |
|
| 322 |
-
def
|
| 323 |
self,
|
| 324 |
idphoto_json,
|
| 325 |
result_image_standard,
|
|
@@ -353,14 +356,16 @@ class IDPhotoProcessor:
|
|
| 353 |
color=watermark_text_color,
|
| 354 |
)
|
| 355 |
|
| 356 |
-
|
| 357 |
-
|
| 358 |
-
|
| 359 |
-
|
| 360 |
-
|
| 361 |
-
|
| 362 |
-
|
| 363 |
-
|
|
|
|
|
|
|
| 364 |
visible=True,
|
| 365 |
)
|
| 366 |
|
|
@@ -390,31 +395,79 @@ class IDPhotoProcessor:
|
|
| 390 |
result_image_hd = add_watermark(image=result_image_hd, **watermark_params)
|
| 391 |
return result_image_standard, result_image_hd
|
| 392 |
|
| 393 |
-
def _resize_image_if_needed(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 394 |
"""如果需要,调整图片大小"""
|
| 395 |
-
|
| 396 |
-
|
| 397 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 398 |
resize_image_to_kb(
|
| 399 |
result_image_standard,
|
| 400 |
-
|
| 401 |
-
|
| 402 |
-
dpi=
|
| 403 |
-
idphoto_json["custom_image_dpi"]
|
| 404 |
-
if idphoto_json["custom_image_dpi"]
|
| 405 |
-
else 300
|
| 406 |
-
),
|
| 407 |
)
|
| 408 |
-
|
| 409 |
-
|
| 410 |
-
elif idphoto_json["custom_image_dpi"]:
|
| 411 |
save_image_dpi_to_bytes(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 412 |
result_image_standard,
|
| 413 |
-
|
| 414 |
-
|
|
|
|
| 415 |
)
|
| 416 |
-
return output_image_path
|
| 417 |
|
|
|
|
|
|
|
|
|
|
| 418 |
return None
|
| 419 |
|
| 420 |
def _create_response(
|
|
|
|
| 16 |
import gradio as gr
|
| 17 |
import os
|
| 18 |
import time
|
| 19 |
+
from demo.locales import LOCALES
|
| 20 |
|
| 21 |
|
| 22 |
class IDPhotoProcessor:
|
|
|
|
| 261 |
)
|
| 262 |
|
| 263 |
# 生成排版照片
|
| 264 |
+
result_image_layout, result_image_layout_gr = self._generate_image_layout(
|
| 265 |
idphoto_json,
|
| 266 |
result_image_standard,
|
| 267 |
language,
|
|
|
|
| 289 |
|
| 290 |
# 调整图片大小
|
| 291 |
output_image_path = self._resize_image_if_needed(
|
| 292 |
+
result_image_standard,
|
| 293 |
+
result_image_hd,
|
| 294 |
+
result_image_layout,
|
| 295 |
+
idphoto_json,
|
| 296 |
)
|
| 297 |
|
| 298 |
return self._create_response(
|
|
|
|
| 300 |
result_image_hd,
|
| 301 |
result_image_standard_png,
|
| 302 |
result_image_hd_png,
|
| 303 |
+
result_image_layout_gr,
|
| 304 |
output_image_path,
|
| 305 |
)
|
| 306 |
|
|
|
|
| 322 |
|
| 323 |
return result_image_standard, result_image_hd
|
| 324 |
|
| 325 |
+
def _generate_image_layout(
|
| 326 |
self,
|
| 327 |
idphoto_json,
|
| 328 |
result_image_standard,
|
|
|
|
| 356 |
color=watermark_text_color,
|
| 357 |
)
|
| 358 |
|
| 359 |
+
result_image_layout = generate_layout_image(
|
| 360 |
+
image,
|
| 361 |
+
typography_arr,
|
| 362 |
+
typography_rotate,
|
| 363 |
+
height=idphoto_json["size"][0],
|
| 364 |
+
width=idphoto_json["size"][1],
|
| 365 |
+
)
|
| 366 |
+
|
| 367 |
+
return result_image_layout, gr.update(
|
| 368 |
+
value=result_image_layout,
|
| 369 |
visible=True,
|
| 370 |
)
|
| 371 |
|
|
|
|
| 395 |
result_image_hd = add_watermark(image=result_image_hd, **watermark_params)
|
| 396 |
return result_image_standard, result_image_hd
|
| 397 |
|
| 398 |
+
def _resize_image_if_needed(
|
| 399 |
+
self,
|
| 400 |
+
result_image_standard,
|
| 401 |
+
result_image_hd,
|
| 402 |
+
result_image_layout,
|
| 403 |
+
idphoto_json,
|
| 404 |
+
):
|
| 405 |
"""如果需要,调整图片大小"""
|
| 406 |
+
# 设置输出路径
|
| 407 |
+
base_path = os.path.join(
|
| 408 |
+
os.path.dirname(os.path.dirname(__file__)), "demo/kb_output"
|
| 409 |
+
)
|
| 410 |
+
timestamp = int(time.time())
|
| 411 |
+
output_paths = {
|
| 412 |
+
"standard": f"{base_path}/{timestamp}_standard",
|
| 413 |
+
"hd": f"{base_path}/{timestamp}_hd",
|
| 414 |
+
"layout": f"{base_path}/{timestamp}_layout",
|
| 415 |
+
}
|
| 416 |
+
|
| 417 |
+
# 获取自定义的KB和DPI值
|
| 418 |
+
custom_kb = idphoto_json.get("custom_image_kb")
|
| 419 |
+
custom_dpi = idphoto_json.get("custom_image_dpi", 300)
|
| 420 |
+
|
| 421 |
+
# 处理同时有自定义KB和DPI的情况
|
| 422 |
+
if custom_kb and custom_dpi:
|
| 423 |
+
# 为所有输出路径添加DPI信息
|
| 424 |
+
for key in output_paths:
|
| 425 |
+
output_paths[key] += f"_{custom_dpi}dpi.jpg"
|
| 426 |
+
# 为标准图像添加KB信息
|
| 427 |
+
output_paths["standard"] = output_paths["standard"].replace(
|
| 428 |
+
".jpg", f"_{custom_kb}kb.jpg"
|
| 429 |
+
)
|
| 430 |
+
|
| 431 |
+
# 调整标准图像大小并保存
|
| 432 |
resize_image_to_kb(
|
| 433 |
result_image_standard,
|
| 434 |
+
output_paths["standard"],
|
| 435 |
+
custom_kb,
|
| 436 |
+
dpi=custom_dpi,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 437 |
)
|
| 438 |
+
# 保存高清图像和排版图像
|
| 439 |
+
save_image_dpi_to_bytes(result_image_hd, output_paths["hd"], dpi=custom_dpi)
|
|
|
|
| 440 |
save_image_dpi_to_bytes(
|
| 441 |
+
result_image_layout, output_paths["layout"], dpi=custom_dpi
|
| 442 |
+
)
|
| 443 |
+
|
| 444 |
+
return list(output_paths.values())
|
| 445 |
+
|
| 446 |
+
# 只有自定义DPI的情况
|
| 447 |
+
elif custom_dpi:
|
| 448 |
+
for key in output_paths:
|
| 449 |
+
output_paths[key] += f"_{custom_dpi}dpi.jpg"
|
| 450 |
+
# 保存所有图像,使用自定义DPI
|
| 451 |
+
save_image_dpi_to_bytes(
|
| 452 |
+
locals()[f"result_image_{key}"], output_paths[key], dpi=custom_dpi
|
| 453 |
+
)
|
| 454 |
+
|
| 455 |
+
return list(output_paths.values())
|
| 456 |
+
|
| 457 |
+
# 只有自定义KB的情况
|
| 458 |
+
elif custom_kb:
|
| 459 |
+
output_paths["standard"] += f"_{custom_kb}kb.jpg"
|
| 460 |
+
# 只调整标准图像大小并保存
|
| 461 |
+
resize_image_to_kb(
|
| 462 |
result_image_standard,
|
| 463 |
+
output_paths["standard"],
|
| 464 |
+
custom_kb,
|
| 465 |
+
dpi=300,
|
| 466 |
)
|
|
|
|
| 467 |
|
| 468 |
+
return [output_paths["standard"]]
|
| 469 |
+
|
| 470 |
+
# 如果没有自定义设置,返回None
|
| 471 |
return None
|
| 472 |
|
| 473 |
def _create_response(
|
demo/ui.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
import os
|
| 3 |
import pathlib
|
| 4 |
-
from demo.
|
| 5 |
from demo.processor import IDPhotoProcessor
|
| 6 |
|
| 7 |
"""
|
|
@@ -23,7 +23,7 @@ def create_ui(
|
|
| 23 |
face_detect_models: list,
|
| 24 |
language: list,
|
| 25 |
):
|
| 26 |
-
DEFAULT_LANG =
|
| 27 |
DEFAULT_HUMAN_MATTING_MODEL = "modnet_photographic_portrait_matting"
|
| 28 |
DEFAULT_FACE_DETECT_MODEL = "retinaface-resnet50"
|
| 29 |
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
import os
|
| 3 |
import pathlib
|
| 4 |
+
from demo.locales import LOCALES
|
| 5 |
from demo.processor import IDPhotoProcessor
|
| 6 |
|
| 7 |
"""
|
|
|
|
| 23 |
face_detect_models: list,
|
| 24 |
language: list,
|
| 25 |
):
|
| 26 |
+
DEFAULT_LANG = "en"
|
| 27 |
DEFAULT_HUMAN_MATTING_MODEL = "modnet_photographic_portrait_matting"
|
| 28 |
DEFAULT_FACE_DETECT_MODEL = "retinaface-resnet50"
|
| 29 |
|
demo/utils.py
CHANGED
|
@@ -1,7 +1,4 @@
|
|
| 1 |
import csv
|
| 2 |
-
import numpy as np
|
| 3 |
-
from PIL import Image
|
| 4 |
-
from hivision.plugin.watermark import Watermarker, WatermarkerStyles
|
| 5 |
|
| 6 |
|
| 7 |
def csv_to_size_list(csv_file: str) -> dict:
|
|
|
|
| 1 |
import csv
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
|
| 4 |
def csv_to_size_list(csv_file: str) -> dict:
|
hivision/creator/human_matting.py
CHANGED
|
@@ -37,10 +37,9 @@ WEIGHTS = {
|
|
| 37 |
),
|
| 38 |
}
|
| 39 |
|
| 40 |
-
ONNX_DEVICE = (
|
| 41 |
-
|
| 42 |
-
if
|
| 43 |
-
else "CPUExecutionProvider"
|
| 44 |
)
|
| 45 |
|
| 46 |
HIVISION_MODNET_SESS = None
|
|
@@ -52,7 +51,7 @@ BIREFNET_V1_LITE_SESS = None
|
|
| 52 |
def load_onnx_model(checkpoint_path, set_cpu=False):
|
| 53 |
providers = (
|
| 54 |
["CUDAExecutionProvider", "CPUExecutionProvider"]
|
| 55 |
-
if
|
| 56 |
else ["CPUExecutionProvider"]
|
| 57 |
)
|
| 58 |
|
|
@@ -365,7 +364,17 @@ def get_birefnet_portrait_matting(input_image, checkpoint_path, ref_size=512):
|
|
| 365 |
|
| 366 |
if BIREFNET_V1_LITE_SESS is None:
|
| 367 |
print("首次加载birefnet-v1-lite模型...")
|
| 368 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 369 |
|
| 370 |
# 记录加载onnx模型的结束时间
|
| 371 |
load_end_time = time()
|
|
|
|
| 37 |
),
|
| 38 |
}
|
| 39 |
|
| 40 |
+
ONNX_DEVICE = onnxruntime.get_device()
|
| 41 |
+
ONNX_PROVIDER = (
|
| 42 |
+
"CUDAExecutionProvider" if ONNX_DEVICE == "GPU" else "CPUExecutionProvider"
|
|
|
|
| 43 |
)
|
| 44 |
|
| 45 |
HIVISION_MODNET_SESS = None
|
|
|
|
| 51 |
def load_onnx_model(checkpoint_path, set_cpu=False):
|
| 52 |
providers = (
|
| 53 |
["CUDAExecutionProvider", "CPUExecutionProvider"]
|
| 54 |
+
if ONNX_PROVIDER == "CUDAExecutionProvider"
|
| 55 |
else ["CPUExecutionProvider"]
|
| 56 |
)
|
| 57 |
|
|
|
|
| 364 |
|
| 365 |
if BIREFNET_V1_LITE_SESS is None:
|
| 366 |
print("首次加载birefnet-v1-lite模型...")
|
| 367 |
+
if ONNX_DEVICE == "GPU":
|
| 368 |
+
print("onnxruntime-gpu已安装,尝试使用CUDA加载模型")
|
| 369 |
+
try:
|
| 370 |
+
import torch
|
| 371 |
+
except ImportError:
|
| 372 |
+
print(
|
| 373 |
+
"torch未安装,尝试直接使用onnxruntime-gpu加载模型,这需要配置好CUDA和cuDNN"
|
| 374 |
+
)
|
| 375 |
+
BIREFNET_V1_LITE_SESS = load_onnx_model(checkpoint_path)
|
| 376 |
+
else:
|
| 377 |
+
BIREFNET_V1_LITE_SESS = load_onnx_model(checkpoint_path, set_cpu=True)
|
| 378 |
|
| 379 |
# 记录加载onnx模型的结束时间
|
| 380 |
load_end_time = time()
|