# --- Imports ---
import os
import sys
import cv2
import torch
import gradio as gr
import numpy as np
from PIL import Image, ImageOps
import io
import base64
import traceback
import tempfile
from fastapi import FastAPI, File, UploadFile
from fastapi.middleware.cors import CORSMiddleware
import spaces
# Import model-specific libraries
try:
    from basicsr.archs.srvgg_arch import SRVGGNetCompact
    from gfpgan.utils import GFPGANer
    from realesrgan.utils import RealESRGANer
    print("Successfully imported model libraries.")
except ImportError as e:
    print(f"Error importing model libraries: {e}")
    print("Please ensure basicsr, gfpgan, and realesrgan are installed.")
    sys.exit(1)
# --- Constants ---
OUTPUT_DIR = 'output'
os.makedirs(OUTPUT_DIR, exist_ok=True)
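# Note: processed results are returned in memory below; OUTPUT_DIR is created
# for convenience but is not written to anywhere else in this script.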
# --- Model Weight Downloads ---
MODEL_FILES = {
    'realesr-general-x4v3.pth': 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth',
    'GFPGANv1.2.pth': 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.2.pth',
    'GFPGANv1.3.pth': 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth',
    'GFPGANv1.4.pth': 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth',
    'RestoreFormer.pth': 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth',
}

print("Downloading model weights...")
for filename, url in MODEL_FILES.items():
    try:
        if not os.path.exists(filename):
            print(f"Downloading {filename}...")
            os.system(f"wget -q {url} -P .")
    except Exception as e:
        print(f"Error downloading {filename}: {e}")
# --- Sample Image Downloads ---
SAMPLE_IMAGES = {
    'lincoln.jpg': 'https://upload.wikimedia.org/wikipedia/commons/thumb/a/ab/Abraham_Lincoln_O-77_matte_collodion_print.jpg/1024px-Abraham_Lincoln_O-77_matte_collodion_print.jpg',
    'AI-generate.jpg': 'https://user-images.githubusercontent.com/17445847/187400315-87a90ac9-d231-45d6-b377-38702bd1838f.jpg',
    'Blake_Lively.jpg': 'https://user-images.githubusercontent.com/17445847/187400981-8a58f7a4-ef61-42d9-af80-bc6234cef860.jpg',
    '10045.png': 'https://user-images.githubusercontent.com/17445847/187401133-8a3bf269-5b4d-4432-b2f0-6d26ee1d3307.png'
}

for filename, url in SAMPLE_IMAGES.items():
    try:
        if not os.path.exists(filename):
            torch.hub.download_url_to_file(url, filename, progress=False)
    except Exception as e:
        print(f"Warning: Error downloading sample image {filename}: {e}")
# --- Model Initialization (Background Enhancer) ---
upsampler = None
try:
    model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
    model_path = 'realesr-general-x4v3.pth'
    half = torch.cuda.is_available()
    upsampler = RealESRGANer(scale=4, model_path=model_path, model=model, tile=0, tile_pad=10, pre_pad=0, half=half)
except Exception as e:
    print(f"Error creating RealESRGAN upsampler: {e}")
    print(traceback.format_exc())
    print("Warning: GFPGAN will run without background enhancement.")
# --- Universal processing function ---
# `spaces` is imported above and this Space runs on ZeroGPU, so the GPU-heavy
# work is wrapped in a @spaces.GPU-decorated function.
@spaces.GPU
def process_image(input_image, version, scale):
    """
    Universal image processing function that handles multiple input types.

    Args:
        input_image: A filepath string, PIL Image, or numpy array.
        version (str): GFPGAN model version ('v1.2', 'v1.3', 'v1.4', 'RestoreFormer').
        scale (float): Rescaling factor for the final output relative to the original.

    Returns:
        tuple: (PIL.Image.Image | None, str | None)
            - Output PIL image (RGB), or None on error.
            - Base64-encoded output image string (data URI), or an error message string.
    """
    input_pil_image = None

    # --- Handle different input types ---
    try:
        # Case 1: Input is a file path string
        if isinstance(input_image, str):
            print(f"Loading image from filepath: {input_image}")
            if not os.path.exists(input_image):
                error_msg = f"Error: Input image filepath does not exist: '{input_image}'"
                print(error_msg)
                return None, error_msg
            input_pil_image = Image.open(input_image)
        # Case 2: Input is already a PIL Image
        elif isinstance(input_image, Image.Image):
            print("Input is already a PIL Image")
            input_pil_image = input_image
        # Case 3: Input is a numpy array (from OpenCV or other sources)
        elif isinstance(input_image, np.ndarray):
            print("Converting numpy array to PIL Image")
            # If it has 3 channels, assume BGR (OpenCV convention) and convert to RGB;
            # check ndim first so grayscale (2-D) arrays do not raise an IndexError
            if input_image.ndim == 3 and input_image.shape[2] == 3:
                input_pil_image = Image.fromarray(cv2.cvtColor(input_image, cv2.COLOR_BGR2RGB))
            else:
                input_pil_image = Image.fromarray(input_image)
        # Case 4: Input might be from Gradio (e.g. a temporary file wrapper)
        else:
            print(f"Unrecognized input type: {type(input_image)}")
            # Try to handle it as a temporary file or other Gradio-provided input
            if hasattr(input_image, "name") and os.path.exists(input_image.name):
                input_pil_image = Image.open(input_image.name)
            else:
                error_msg = f"Unsupported input type: {type(input_image)}"
                print(error_msg)
                return None, error_msg
        print(f"Successfully loaded image. Mode: {input_pil_image.mode}, size: {input_pil_image.size}")
    except Exception as load_err:
        error_msg = f"Error loading image: {load_err}"
        print(error_msg)
        print(traceback.format_exc())
        return None, error_msg

    if input_pil_image is None:
        return None, "Error: Failed to load input image."
print(f"Processing image with GFPGAN version: {version}, scale: {scale}") | |
# --- Handle EXIF Orientation --- | |
original_size_before_exif = input_pil_image.size | |
try: | |
input_pil_image = ImageOps.exif_transpose(input_pil_image) | |
if input_pil_image.size != original_size_before_exif: | |
print(f"Image size changed by EXIF transpose: {original_size_before_exif} -> {input_pil_image.size}") | |
except Exception as exif_err: | |
print(f"Warning: Could not apply EXIF transpose: {exif_err}") | |
w_orig, h_orig = input_pil_image.size | |
print(f"Input size for processing (WxH): {w_orig}x{h_orig}") | |
# Convert PIL Image to OpenCV format (BGR numpy array) | |
try: | |
img_mode = input_pil_image.mode | |
if img_mode != 'RGB': | |
print(f"Converting input image from {img_mode} to RGB") | |
input_pil_image = input_pil_image.convert('RGB') | |
img_bgr = np.array(input_pil_image)[:, :, ::-1].copy() | |
except Exception as conversion_err: | |
error_msg = f"Error converting PIL image to OpenCV format: {conversion_err}" | |
print(error_msg) | |
return None, error_msg | |
# --- Start GFPGAN Processing --- | |
try: | |
h, w = img_bgr.shape[0:2] | |
if h > 4000 or w > 4000: | |
print(f'Warning: Image size ({w}x{h}) is very large, processing might be slow or fail.') | |
model_map = { | |
'v1.2': 'GFPGANv1.2.pth', 'v1.3': 'GFPGANv1.3.pth', | |
'v1.4': 'GFPGANv1.4.pth', 'RestoreFormer': 'RestoreFormer.pth' | |
} | |
arch_map = { | |
'v1.2': 'clean', 'v1.3': 'clean', 'v1.4': 'clean', | |
'RestoreFormer': 'RestoreFormer' | |
} | |
if version not in model_map: | |
error_msg = f"Error: Unknown version selected: {version}" | |
print(error_msg) | |
return None, error_msg | |
model_path = model_map[version] | |
arch = arch_map[version] | |
if not os.path.exists(model_path): | |
error_msg = f"Error: Model file not found for version {version}: {model_path}" | |
print(error_msg) | |
return None, error_msg | |
current_bg_upsampler = upsampler | |
if not current_bg_upsampler: | |
print("Warning: RealESRGAN upsampler not available. Background enhancement disabled.") | |
face_enhancer = GFPGANer( | |
model_path=model_path, upscale=2, arch=arch, | |
channel_multiplier=2, bg_upsampler=current_bg_upsampler | |
) | |
print(f"Running GFPGAN enhancement with {version}...") | |
_, _, output_bgr = face_enhancer.enhance( | |
img_bgr, has_aligned=False, only_center_face=False, paste_back=True | |
) | |
if output_bgr is None: | |
error_msg = "Error: GFPGAN enhancement returned None." | |
print(error_msg) | |
return None, error_msg | |
print(f"Enhancement complete. Intermediate output shape (HxWxC BGR): {output_bgr.shape}") | |
# --- Post-processing (Resizing) --- | |
target_scale_factor = float(scale) | |
h_gfpgan, w_gfpgan = output_bgr.shape[0:2] | |
target_w = int(w_orig * target_scale_factor) | |
target_h = int(h_orig * target_scale_factor) | |
if target_w <= 0 or target_h <= 0: | |
print(f"Warning: Invalid target size ({target_w}x{target_h}) calculated from scale {scale}. Using GFPGAN output size {w_gfpgan}x{h_gfpgan}.") | |
target_w, target_h = w_gfpgan, h_gfpgan | |
if abs(target_w - w_gfpgan) > 2 or abs(target_h - h_gfpgan) > 2: | |
print(f"Resizing GFPGAN output ({w_gfpgan}x{h_gfpgan}) to target ({target_w}x{target_h}) based on scale {target_scale_factor}...") | |
interpolation = cv2.INTER_LANCZOS4 if (target_w * target_h) > (w_gfpgan * h_gfpgan) else cv2.INTER_AREA | |
try: | |
output_bgr = cv2.resize(output_bgr, (target_w, target_h), interpolation=interpolation) | |
except cv2.error as resize_err: | |
error_msg = f"Error during OpenCV resize: {resize_err}. Returning image before final resize attempt." | |
print(error_msg) | |
output_pil = Image.fromarray(cv2.cvtColor(output_bgr, cv2.COLOR_BGR2RGB)) | |
base64_output = None | |
try: | |
buffered = io.BytesIO() | |
output_pil.save(buffered, format="WEBP", quality=85) | |
img_str = base64.b64encode(buffered.getvalue()).decode("utf-8") | |
base64_output = f"data:image/webp;base64,{img_str}" | |
except Exception as enc_err: | |
print(f"Error encoding fallback image: {enc_err}") | |
error_msg += f" | Encoding Error: {enc_err}" | |
return output_pil, base64_output if base64_output else error_msg | |
# --- Convert final result back to PIL (RGB) --- | |
output_pil = Image.fromarray(cv2.cvtColor(output_bgr, cv2.COLOR_BGR2RGB)) | |
print(f"Final output image size (WxH PIL): {output_pil.size}") | |
# --- Encode final PIL image to Base64 for API --- | |
base64_output = None | |
try: | |
buffered = io.BytesIO() | |
output_pil.save(buffered, format="WEBP", quality=90) | |
img_str = base64.b64encode(buffered.getvalue()).decode("utf-8") | |
base64_output = f"data:image/webp;base64,{img_str}" | |
except Exception as enc_err: | |
error_msg = f"Error encoding final image to base64: {enc_err}" | |
print(error_msg) | |
return output_pil, error_msg | |
success_msg = f"Success! Output size: {output_pil.size[0]}x{output_pil.size[1]}" | |
return output_pil, base64_output if base64_output else success_msg | |
except Exception as error: | |
error_msg = f"Error during GFPGAN processing: {error}" | |
print(error_msg) | |
print(traceback.format_exc()) | |
error_img = None | |
try: | |
error_img = Image.new('RGB', (100, 50), color = 'red') | |
except Exception: | |
pass | |
return error_img, error_msg | |
# --- Function to handle file upload for API ---
def handle_file_upload(file_data):
    """Save uploaded file to a temporary path and return that path."""
    try:
        print(f"Handling file upload: {type(file_data)}")
        # Create a temporary file (close the handle; content is written via a fresh handle below)
        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.jpg')
        temp_path = temp_file.name
        temp_file.close()
        # If it's bytes, write directly
        if isinstance(file_data, bytes):
            with open(temp_path, 'wb') as f:
                f.write(file_data)
        # If it's a file-like object (from FastAPI/Gradio)
        elif hasattr(file_data, 'file'):
            content = file_data.file.read()
            with open(temp_path, 'wb') as f:
                f.write(content)
        # If it's a string path, it's already saved
        elif isinstance(file_data, str) and os.path.exists(file_data):
            return file_data
        else:
            raise ValueError(f"Unsupported file data type: {type(file_data)}")
        print(f"File saved to temporary path: {temp_path}")
        return temp_path
    except Exception as e:
        print(f"Error handling file upload: {e}")
        print(traceback.format_exc())
        raise
# --- API inference function ---
def inference(input_image, version, scale):
    """
    API-friendly wrapper that ensures consistent behavior between the web and API interfaces.
    """
    try:
        # If input is a file upload (from the API), save it to a temporary path
        if not isinstance(input_image, (str, Image.Image, np.ndarray)) and not (hasattr(input_image, 'name') and os.path.exists(input_image.name)):
            file_path = handle_file_upload(input_image)
            input_image = file_path
        # Process the image
        output_pil, base64_or_msg = process_image(input_image, version, scale)
        # Return the processed results
        return output_pil, base64_or_msg
    except Exception as e:
        print(f"Error in inference: {e}")
        print(traceback.format_exc())
        # Return a placeholder error image and message
        error_img = Image.new('RGB', (100, 50), color='red')
        return error_img, f"Error: {str(e)}"
# --- Create the FastAPI app (the Gradio UI is mounted onto it below) ---
app = FastAPI()

# Add CORS middleware to allow cross-origin requests
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allows all origins
    allow_credentials=True,
    allow_methods=["*"],  # Allows all methods
    allow_headers=["*"],  # Allows all headers
)
# --- Direct API endpoint for file upload ---
@app.post("/api/direct-process")
async def direct_process(file: UploadFile = File(...), version: str = "v1.4", scale: float = 2.0):
    try:
        # Save the uploaded file
        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.jpg')
        temp_path = temp_file.name
        temp_file.close()
        with open(temp_path, 'wb') as f:
            f.write(await file.read())
        # Process the image
        _, base64_image = process_image(temp_path, version, scale)
        # Clean up
        os.unlink(temp_path)
        # Return base64 image data
        if base64_image and base64_image.startswith('data:image'):
            return {"success": True, "image": base64_image}
        else:
            return {"success": False, "error": base64_image or "Unknown error"}
    except Exception as e:
        print(f"Error in direct-process API: {e}")
        print(traceback.format_exc())
        return {"success": False, "error": str(e)}
# --- Gradio Interface Definition ---
title = "GFPGAN: Practical Face Restoration"
description = """Gradio demo for <a href='https://github.com/TencentARC/GFPGAN' target='_blank'><b>GFPGAN: Towards Real-World Blind Face Restoration with Generative Facial Prior</b></a>.
<br>Restore your <b>old photos</b> or improve <b>AI-generated faces</b>. Upload an image to start.
<br>If helpful, please ⭐ the <a href='https://github.com/TencentARC/GFPGAN' target='_blank'>Original Github Repo</a>.
<br>API endpoint available at `/predict` or `/api/direct-process`. Returns the processed image and base64 data.
"""
article = "Questions? Contact the original creators (see the GFPGAN repo)."

# Use upload component for more compatibility
inputs = [
    gr.Image(type="pil", label="Input Image", sources=["upload", "clipboard"]),
    gr.Radio(
        ['v1.2', 'v1.3', 'v1.4', 'RestoreFormer'],
        type="value", value='v1.4', label='GFPGAN Version',
        info="v1.4 recommended. RestoreFormer for diverse poses."
    ),
    gr.Number(
        label="Rescaling Factor", value=2,
        info="Final output size multiplier relative to the original input size (e.g., 2 = 2x original WxH)."
    ),
]
outputs = [
    gr.Image(type="pil", label="Output Image"),
    gr.Textbox(label="Output Info / Base64 Data", interactive=False, visible=True)
]
examples = [
    ['AI-generate.jpg', 'v1.4', 2],
    ['lincoln.jpg', 'v1.4', 2],
    ['Blake_Lively.jpg', 'v1.4', 2],
    ['10045.png', 'v1.4', 2]
]
# --- Gradio Interface Instantiation ---
demo = gr.Interface(
    fn=inference,
    inputs=inputs,
    outputs=outputs,
    title=title,
    description=description,
    article=article,
    examples=examples,
    cache_examples=False,
    allow_flagging='never'
)

# Mount the Gradio app onto the FastAPI app
app = gr.mount_gradio_app(app, demo, path="/")

# Launch the interface
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)