face / app.py
kimhyunwoo's picture
Update app.py
9106952 verified
raw
history blame
6.42 kB
from fastapi import FastAPI
import gradio as gr
from PIL import Image
import numpy as np
import torch
from transformers import pipeline
import cv2 # cv2๋Š” ์ด๋ฏธ์ง€ ๋กœ๋“œ/์ €์žฅ/์ „์ฒ˜๋ฆฌ ๋“ฑ์— ์‚ฌ์šฉ๋  ์ˆ˜ ์žˆ์ง€๋งŒ, ํ˜„์žฌ ์ฝ”๋“œ์—์„œ๋Š” PIL๋งŒ์œผ๋กœ๋„ ์ถฉ๋ถ„ํ•ฉ๋‹ˆ๋‹ค.
app = FastAPI()
# ๋”ฅ๋Ÿฌ๋‹ ๋ชจ๋ธ ๋กœ๋“œ (Depth Anything)
# ๋ชจ๋ธ ๋กœ๋”ฉ์€ ์•ฑ ์‹œ์ž‘ ์‹œ ํ•œ ๋ฒˆ๋งŒ ํ•˜๋„๋ก ๊ธ€๋กœ๋ฒŒ ๋ณ€์ˆ˜๋กœ ์„ค์ •
print("Loading Depth Anything model...")
try:
# ๐ŸŒŸ๐ŸŒŸ๐ŸŒŸ ๋ชจ๋ธ ์ด๋ฆ„ ์ˆ˜์ •: ์ •ํ™•ํ•œ Depth Anything v2 Large ๋ชจ๋ธ ID ๐ŸŒŸ๐ŸŒŸ๐ŸŒŸ
# ๋” ์ž‘์€ ๋ชจ๋ธ์„ ์›ํ•˜์‹œ๋ฉด "LiangNX/depth-anything-v2-base-nyu" ๋กœ ๋ณ€๊ฒฝํ•˜์„ธ์š”.
depth_estimator = pipeline(task="depth-estimation", model="LiangNX/depth-anything-v2-large-nyu", device="cpu") # GPU ์‚ฌ์šฉ ์‹œ device="cuda"
print("Depth Anything model loaded successfully.")
except Exception as e:
print(f"Error loading Depth Anything model: {e}")
depth_estimator = None # ๋ชจ๋ธ ๋กœ๋“œ ์‹คํŒจ ์‹œ None์œผ๋กœ ์„ค์ •
def process_image_for_depth(image_path_or_pil_image):
"""
๋‹จ์ผ ์ด๋ฏธ์ง€์—์„œ ๋ށ์Šค ๋งต์„ ์ถ”์ถœํ•˜๋Š” ํ•จ์ˆ˜.
"""
if depth_estimator is None:
return None, "Error: Depth Anything model not loaded. Check server logs."
# Gradio๋Š” PIL Image ๊ฐ์ฒด๋กœ ์ด๋ฏธ์ง€๋ฅผ ์ „๋‹ฌํ•ฉ๋‹ˆ๋‹ค.
if isinstance(image_path_or_pil_image, str):
# ํŒŒ์ผ ๊ฒฝ๋กœ๋กœ ์ด๋ฏธ์ง€๊ฐ€ ์ „๋‹ฌ๋œ ๊ฒฝ์šฐ (์˜ˆ: Gradio์˜ `filepath` ํƒ€์ž…)
image = Image.open(image_path_or_pil_image).convert("RGB")
else:
# PIL Image ๊ฐ์ฒด๊ฐ€ ์ง์ ‘ ์ „๋‹ฌ๋œ ๊ฒฝ์šฐ
image = image_path_or_pil_image.convert("RGB")
try:
# Depth Anything ๋ชจ๋ธ ์ถ”๋ก 
# result๋Š” ๋”•์…”๋„ˆ๋ฆฌ๋กœ, 'depth' (PIL Image)์™€ 'depth_npy' (numpy array)๋ฅผ ํฌํ•จํ•ฉ๋‹ˆ๋‹ค.
result = depth_estimator(image)
# ๋ށ์Šค ๋งต (PIL Image) - ์ด PIL ์ด๋ฏธ์ง€๋Š” ์ด๋ฏธ ์‹œ๊ฐํ™”ํ•˜๊ธฐ ์ข‹์€ ํ˜•ํƒœ๋กœ ๋˜์–ด ์žˆ์Šต๋‹ˆ๋‹ค.
depth_image_pil = result["depth"]
# ๋ށ์Šค ๋งต (Numpy Array) - ํ•„์š”ํ•˜๋‹ค๋ฉด ์ถ”๊ฐ€ ์ฒ˜๋ฆฌ (์˜ˆ: ์ •๊ทœํ™”, ๋‹ค๋ฅธ ํ˜•์‹์œผ๋กœ ๋ณ€ํ™˜)
# ํ˜„์žฌ๋Š” ํ‘๋ฐฑ ์ด๋ฏธ์ง€๋กœ ๋ฐ”๋กœ ์‚ฌ์šฉํ•ด๋„ ์ถฉ๋ถ„ํ•˜๋ฏ€๋กœ, ์•„๋ž˜ ์ฝ”๋“œ๋Š” ์„ ํƒ ์‚ฌํ•ญ์ž…๋‹ˆ๋‹ค.
# depth_np = result["depth_npy"]
# normalized_depth_np = (depth_np - depth_np.min()) / (depth_np.max() - depth_np.min()) * 255
# normalized_depth_np = normalized_depth_np.astype(np.uint8)
# depth_grayscale_pil = Image.fromarray(normalized_depth_np)
return depth_image_pil, None # ๋ށ์Šค ๋งต PIL Image ๋ฐ˜ํ™˜
except Exception as e:
import traceback
traceback.print_exc() # ์—๋Ÿฌ ๋ฐœ์ƒ ์‹œ ์ƒ์„ธ ํŠธ๋ ˆ์ด์Šค๋ฐฑ ์ถœ๋ ฅ
return None, f"Error processing image for depth: {e}"
# Gradio ์ธํ„ฐํŽ˜์ด์Šค ์ •์˜
with gr.Blocks() as demo:
gr.Markdown("# ๐Ÿง‘โ€๐Ÿ’ป ์–ผ๊ตด ๋ށ์Šค ๋งต ์ถ”์ถœ๊ธฐ")
gr.Markdown("์—ฌ๋Ÿฌ ์žฅ์˜ ์–ผ๊ตด ์‚ฌ์ง„์„ ์—…๋กœ๋“œํ•˜๋ฉด ๊ฐ ์‚ฌ์ง„์—์„œ ๋”ฅ๋Ÿฌ๋‹์„ ํ†ตํ•ด ๋ށ์Šค ๋งต(๊นŠ์ด ์ •๋ณด)์„ ์ถ”์ถœํ•ฉ๋‹ˆ๋‹ค.")
gr.Markdown("โš ๏ธ **์ฐธ๊ณ :** ๋ฌด๋ฃŒ CPU ํ™˜๊ฒฝ์—์„œ๋Š” ๋ชจ๋ธ ๋กœ๋”ฉ ๋ฐ ์ด๋ฏธ์ง€ ์ฒ˜๋ฆฌ ์‹œ๊ฐ„์ด ๋งค์šฐ ์˜ค๋ž˜ ๊ฑธ๋ฆด ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.")
with gr.Row():
# file_count="multiple"๋กœ ์—ฌ๋Ÿฌ ํŒŒ์ผ ์—…๋กœ๋“œ ๊ฐ€๋Šฅํ•˜๊ฒŒ ์„ค์ •
# type="filepath"๋กœ ์„ค์ •ํ•˜์—ฌ Gradio๊ฐ€ ํŒŒ์ผ์„ ์ž„์‹œ ๊ฒฝ๋กœ์— ์ €์žฅํ•˜๋„๋ก ํ•จ
input_images = gr.File(label="์–ผ๊ตด ์‚ฌ์ง„ ์—…๋กœ๋“œ (์ตœ๋Œ€ 10์žฅ ๊ถŒ์žฅ)", file_count="multiple", type="filepath")
output_gallery = gr.Gallery(label="์›๋ณธ ์ด๋ฏธ์ง€ ๋ฐ ๋ށ์Šค ๋งต", columns=[2], rows=[1], object_fit="contain", height="auto")
process_button = gr.Button("๋ށ์Šค ๋งต ์ถ”์ถœ ์‹œ์ž‘")
def process_all_images(image_paths):
"""
์—…๋กœ๋“œ๋œ ๋ชจ๋“  ์ด๋ฏธ์ง€์— ๋Œ€ํ•ด ๋ށ์Šค ๋งต ์ถ”์ถœ์„ ์ˆ˜ํ–‰ํ•˜๊ณ  ๊ฒฐ๊ณผ๋ฅผ ๋ฐ˜ํ™˜ํ•˜๋Š” ํ•จ์ˆ˜.
"""
if not image_paths:
return [(None, "์ด๋ฏธ์ง€๋ฅผ ์—…๋กœ๋“œํ•ด์ฃผ์„ธ์š”.")]
results_for_gallery = [] # Gradio Gallery์— ํ‘œ์‹œํ•  (PIL Image, Label) ํŠœํ”Œ ๋ฆฌ์ŠคํŠธ
for i, path in enumerate(image_paths):
try:
original_image = Image.open(path).convert("RGB")
depth_map_pil, error = process_image_for_depth(original_image)
# ์›๋ณธ ์ด๋ฏธ์ง€๋ฅผ ๋จผ์ € ์ถ”๊ฐ€
results_for_gallery.append((original_image, f"์›๋ณธ ์ด๋ฏธ์ง€ {i+1}"))
if error:
print(f"Error processing image {i+1}: {error}")
# ์˜ค๋ฅ˜๊ฐ€ ๋ฐœ์ƒํ•œ ๊ฒฝ์šฐ ๋ށ์Šค ๋งต ๋Œ€์‹  ์˜ค๋ฅ˜ ๋ฉ”์‹œ์ง€ ํ‘œ์‹œ
# ์ด ๋ถ€๋ถ„์€ Gradio Gallery๊ฐ€ (Image, Label)์„ ๊ธฐ๋Œ€ํ•˜๋ฏ€๋กœ,
# ์˜ค๋ฅ˜ ๋ฉ”์‹œ์ง€๋ฅผ ํ…์ŠคํŠธ ์ด๋ฏธ์ง€๋กœ ๋งŒ๋“ค๊ฑฐ๋‚˜, ์˜ค๋ฅ˜ ์ด๋ฏธ์ง€๋ฅผ ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
# ์—ฌ๊ธฐ์„œ๋Š” ๊ฐ„๋‹จํžˆ ๋นˆ ์ด๋ฏธ์ง€์™€ ํ•จ๊ป˜ ์˜ค๋ฅ˜ ๋ฉ”์‹œ์ง€๋ฅผ ํ‘œ์‹œํ•˜๋„๋ก ํ•ฉ๋‹ˆ๋‹ค.
dummy_error_image = Image.new('RGB', original_image.size, color = 'red')
results_for_gallery.append((dummy_error_image, f"์˜ค๋ฅ˜ ๋ฐœ์ƒ: {error}"))
else:
results_for_gallery.append((depth_map_pil, f"๋ށ์Šค ๋งต {i+1}"))
except Exception as e:
# ํŒŒ์ผ ์ž์ฒด๋ฅผ ์—ฌ๋Š” ๋ฐ ๋ฌธ์ œ๊ฐ€ ๋ฐœ์ƒํ•œ ๊ฒฝ์šฐ
print(f"Failed to open or process file {path}: {e}")
dummy_error_image = Image.new('RGB', (200, 200), color = 'red') # ์ž‘์€ ์—๋Ÿฌ ์ด๋ฏธ์ง€
results_for_gallery.append((dummy_error_image, f"ํŒŒ์ผ ์ฒ˜๋ฆฌ ์˜ค๋ฅ˜: {e}"))
return results_for_gallery
# ๋ฒ„ํŠผ ํด๋ฆญ ์‹œ ํ•จ์ˆ˜ ์—ฐ๊ฒฐ
process_button.click(
fn=process_all_images,
inputs=input_images,
outputs=output_gallery
)
# Gradio ์•ฑ์„ FastAPI์— ๋งˆ์šดํŠธ
# ๊ธฐ๋ณธ ๊ฒฝ๋กœ (path="/")์— Gradio ์•ฑ์„ ๋งˆ์šดํŠธํ•˜์—ฌ ์›น ํŽ˜์ด์ง€ ์ ‘์† ์‹œ ๋ฐ”๋กœ UI๊ฐ€ ๋ณด์ด๋„๋ก ํ•ฉ๋‹ˆ๋‹ค.
app = gr.mount_gradio_app(app, demo, path="/")
# FastAPI ๊ธฐ๋ณธ ์—”๋“œํฌ์ธํŠธ (์„ ํƒ ์‚ฌํ•ญ, Gradio ์•ฑ์ด ๊ธฐ๋ณธ ๊ฒฝ๋กœ๋ฅผ ์ ์œ ํ•จ)
# ์ด ์—”๋“œํฌ์ธํŠธ๋Š” /api ๊ฒฝ๋กœ๋กœ ์ ‘์†ํ–ˆ์„ ๋•Œ๋งŒ ์ž‘๋™ํ•ฉ๋‹ˆ๋‹ค.
@app.get("/api")
def read_root():
return {"message": "Welcome to the Face Depth Map Extractor! Visit / for the UI."}