|
|
""" |
|
|
FacePoke API |
|
|
|
|
|
Author: Julian Bilcke |
|
|
Date: September 30, 2024 |
|
|
""" |
|
|
|
|
|
import sys |
|
|
import asyncio |
|
|
from aiohttp import web, WSMsgType |
|
|
import json |
|
|
from json import JSONEncoder |
|
|
import numpy as np |
|
|
import uuid |
|
|
import logging |
|
|
import os |
|
|
import signal |
|
|
from typing import Dict, Any, List, Optional |
|
|
import base64 |
|
|
import io |
|
|
|
|
|
from PIL import Image |
|
|
|
|
|
|
|
|
import pillow_avif |
|
|
|
|
|
|
|
|
# Configure the root logger at import time: INFO level and above, with a
# timestamp, logger name and level on every record.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)

# Module-level logger shared by every function in this file.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def SIGSEGV_signal_arises(signalNum, stack):
    """Log a critical report when the process is delivered SIGSEGV.

    Args:
        signalNum: the signal number passed in by the ``signal`` module.
        stack: the interrupted stack frame passed in by the ``signal``
            module; may be ``None``.
    """
    # Imported locally so this handler block is self-contained.
    import traceback

    logger.critical(f"{signalNum} : SIGSEGV arises")
    if stack is None:
        logger.critical("Stack trace: <unavailable>")
    else:
        # Format the full call stack leading to the interrupted frame;
        # interpolating the frame object itself would only print its repr.
        logger.critical("Stack trace:\n" + "".join(traceback.format_stack(stack)))


# NOTE: per the Python `signal` documentation, a Python-level handler cannot
# reliably handle a SIGSEGV caused by an invalid memory access in C code.
# Python handlers run later, between bytecode instructions, and returning to
# the faulting C code is likely to re-raise the signal, making the process
# appear to hang. The stdlib `faulthandler` module is the documented tool for
# dumping tracebacks on such crashes.
signal.signal(signal.SIGSEGV, SIGSEGV_signal_arises)
|
|
|
|
|
from loader import initialize_models |
|
|
from engine import Engine, base64_data_uri_to_PIL_Image |
|
|
|
|
|
|
|
|
# Base directory for runtime data; overridable through the DATA_ROOT
# environment variable (defaults to /tmp/data).
DATA_ROOT = os.getenv("DATA_ROOT", "/tmp/data")

# "models" sub-directory beneath DATA_ROOT.
MODELS_DIR = os.path.join(DATA_ROOT, "models")
|
|
|
|
|
class NumpyEncoder(json.JSONEncoder):
    """JSON encoder that converts NumPy values into plain Python equivalents.

    Intended for ``json.dumps(..., cls=NumpyEncoder)`` so that results holding
    NumPy scalars or arrays can be serialized.
    """

    def default(self, obj):
        """Return a JSON-serializable stand-in for *obj*.

        - NumPy integer scalars  -> ``int``
        - NumPy floating scalars -> ``float``
        - NumPy boolean scalars  -> ``bool``
        - ``np.ndarray``         -> nested Python lists via ``tolist()``

        Anything else is delegated to ``json.JSONEncoder.default``, which
        raises ``TypeError``.
        """
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        # np.bool_ is not a subclass of np.integer, so it needs its own
        # branch; json.dumps would otherwise raise TypeError for it.
        if isinstance(obj, np.bool_):
            return bool(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return super().default(obj)
|
|
|
|
|
async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
    """Serve one WebSocket client until the connection ends.

    Message handling, as implemented below:
      * a BINARY frame is passed to ``engine.load_image``; the result is sent
        back as a JSON text frame (NumPy values converted by ``NumpyEncoder``);
      * a TEXT frame is parsed as JSON and its ``uuid`` and ``params`` entries
        are passed to ``engine.transform_image``; the returned payload is sent
        back as a BINARY frame.

    A failure while handling a single message is logged and reported to the
    client as ``{"error": "<message>"}``; the connection stays open.

    Args:
        request: the HTTP request to upgrade; ``request.app['engine']`` must
            hold the shared engine instance.

    Returns:
        The WebSocketResponse, once the receive loop has ended.
    """
    ws = web.WebSocketResponse()
    await ws.prepare(request)

    engine = request.app['engine']

    try:
        # Iterating the response ends the loop on CLOSE, CLOSING and CLOSED
        # messages, so a socket that has already shut down cannot keep the
        # loop spinning on repeated CLOSED messages.
        async for msg in ws:
            if msg.type == WSMsgType.ERROR:
                # ERROR is still yielded by the iterator; stop serving this client.
                logger.error(f"WebSocket connection closed with exception: {ws.exception()}")
                break

            try:
                if msg.type == WSMsgType.BINARY:
                    res = await engine.load_image(msg.data)
                    json_res = json.dumps(res, cls=NumpyEncoder)
                    await ws.send_str(json_res)

                elif msg.type == WSMsgType.TEXT:
                    data = json.loads(msg.data)
                    webp_bytes = await engine.transform_image(data.get('uuid'), data.get('params'))
                    await ws.send_bytes(webp_bytes)

            except Exception as e:
                logger.error(f"Error in engine: {str(e)}")
                logger.exception("Full traceback:")
                # The client may have gone away while the engine was working;
                # sending on a closed socket would raise and abort the session
                # via the outer handler instead of just reporting the error.
                if not ws.closed:
                    await ws.send_json({"error": str(e)})

    except Exception as e:
        logger.error(f"Error in websocket_handler: {str(e)}")
        logger.exception("Full traceback:")

    return ws
|
|
|
|
|
async def index(request: web.Request) -> web.Response:
    """Serve the index.html file.

    The file is read from the ``public`` directory next to this module on
    every request and returned as ``text/html``.
    """
    path = os.path.join(os.path.dirname(__file__), "public", "index.html")
    # `with` closes the handle as soon as the file is read; an explicit
    # encoding avoids depending on the host locale's default.
    with open(path, "r", encoding="utf-8") as f:
        content = f.read()
    return web.Response(content_type="text/html", text=content)
|
|
|
|
|
async def js_index(request: web.Request) -> web.Response:
    """Serve the index.js file.

    The file is read from the ``public`` directory next to this module on
    every request and returned as ``application/javascript``.
    """
    path = os.path.join(os.path.dirname(__file__), "public", "index.js")
    # `with` closes the handle as soon as the file is read; an explicit
    # encoding avoids depending on the host locale's default.
    with open(path, "r", encoding="utf-8") as f:
        content = f.read()
    return web.Response(content_type="application/javascript", text=content)
|
|
|
|
|
async def hf_logo(request: web.Request) -> web.Response:
    """Serve the hf-logo.svg file.

    The file is read from the ``public`` directory next to this module on
    every request and returned as ``image/svg+xml``.
    """
    path = os.path.join(os.path.dirname(__file__), "public", "hf-logo.svg")
    # `with` closes the handle as soon as the file is read; an explicit
    # encoding avoids depending on the host locale's default.
    with open(path, "r", encoding="utf-8") as f:
        content = f.read()
    return web.Response(content_type="image/svg+xml", text=content)
|
|
|
|
|
async def initialize_app() -> web.Application:
    """Initialize and configure the web application.

    Loads the models via ``initialize_models()`` and wraps them in an
    ``Engine``. The engine is stored under ``app['engine']``, which
    ``websocket_handler`` reads. Then registers the static-file routes and
    the ``/ws`` WebSocket route.

    Returns:
        The configured ``web.Application``.

    Raises:
        Any exception from model loading, engine construction or route setup,
        re-raised after being logged with its traceback.
    """
    try:
        logger.info("Initializing application...")
        live_portrait = await initialize_models()

        logger.info("🚀 Creating Engine instance...")
        engine = Engine(live_portrait=live_portrait)
        logger.info("✅ Engine instance created.")

        app = web.Application()
        app['engine'] = engine

        app.router.add_get("/", index)
        app.router.add_get("/index.js", js_index)
        app.router.add_get("/hf-logo.svg", hf_logo)
        app.router.add_get("/ws", websocket_handler)

        logger.info("Application routes configured")

        return app

    except Exception as e:
        logger.error(f"🚨 Error during application initialization: {str(e)}")
        logger.exception("Full traceback:")
        raise
|
|
|
|
|
if __name__ == "__main__":

    async def _build_app() -> web.Application:
        """Build the application, then announce that serving is about to start."""
        app = await initialize_app()
        logger.info("Application initialized, starting web server")
        return app

    try:
        logger.info("Starting FacePoke application")
        # run_app awaits the coroutine on the same event loop it then serves
        # on. Building the app with asyncio.run() first would create the
        # engine on a loop that is closed before the server starts, which
        # could strand anything the models/engine bound to that loop.
        web.run_app(_build_app(), host="0.0.0.0", port=8080)
    except Exception as e:
        logger.critical(f"🚨 FATAL: Failed to start the app: {str(e)}")
        logger.exception("Full traceback:")
        # Exit non-zero so process supervisors see the failed start.
        sys.exit(1)
|
|
|