import logging
import random
import warnings
import os
import io
import base64

import gradio as gr
import numpy as np
import spaces
import torch
from diffusers import FluxControlNetModel
from diffusers.pipelines import FluxControlNetPipeline
from gradio_imageslider import ImageSlider
from PIL import Image, ImageOps
from huggingface_hub import snapshot_download

logging.basicConfig(level=logging.INFO)
warnings.filterwarnings("ignore")

css = """ |
|
#col-container { |
|
margin: 0 auto; |
|
max-width: 512px; /* Increased max-width slightly for better layout */ |
|
} |
|
.gradio-container { |
|
max-width: 900px !important; /* Control overall container width */ |
|
margin: auto !important; |
|
} |
|
""" |
|
|
|
if torch.cuda.is_available():
    power_device = "GPU"
    device = "cuda"
    torch_dtype = torch.bfloat16
else:
    power_device = "CPU"
    device = "cpu"
    torch_dtype = torch.float32

logging.info(f"Selected device: {device} | Data type: {torch_dtype}")

# FLUX.1-dev is a gated model, so a token with access is required to download it.
huggingface_token = os.getenv("HUGGINGFACE_TOKEN")

flux_model_id = "black-forest-labs/FLUX.1-dev"
controlnet_model_id = "jasperai/Flux.1-dev-Controlnet-Upscaler"
local_model_dir = flux_model_id.split("/")[-1]
pipe = None

try:
    logging.info(f"Downloading base model: {flux_model_id}")
    model_path = snapshot_download(
        repo_id=flux_model_id,
        repo_type="model",
        ignore_patterns=["*.md", "*.gitattributes"],
        local_dir=local_model_dir,
        token=huggingface_token,
    )
    logging.info(f"Base model downloaded/verified in: {model_path}")

    logging.info(f"Loading ControlNet model: {controlnet_model_id}")
    controlnet = FluxControlNetModel.from_pretrained(
        controlnet_model_id, torch_dtype=torch_dtype
    ).to(device)
    logging.info("ControlNet model loaded.")

    logging.info("Loading FluxControlNetPipeline...")
    pipe = FluxControlNetPipeline.from_pretrained(
        model_path,
        controlnet=controlnet,
        torch_dtype=torch_dtype,
    )
    pipe.to(device)
    logging.info("Pipeline loaded and moved to device.")

except Exception as e:
    logging.error(f"FATAL: Error during model loading: {e}", exc_info=True)
    # Also print directly, in case logging output is not visible in the Space console.
    print(f"FATAL ERROR DURING MODEL LOAD: {e}")
    raise SystemExit(f"Model loading failed: {e}")
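
# Memory note (an optional alternative, not enabled here): on GPUs with less VRAM,
# diffusers pipelines can offload submodules to the CPU instead of `pipe.to(device)`.
# A minimal sketch, assuming `accelerate` is installed:
#
#   pipe.enable_model_cpu_offload()  # trades inference speed for lower VRAM use
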
MAX_SEED = 2**32 - 1
# Cap on the intermediate (model-generated) resolution, in pixels.
MAX_PIXEL_BUDGET = 1280 * 1280
# The model always upscales at this fixed factor; the result is then resized
# to the user's requested final factor.
INTERNAL_PROCESSING_FACTOR = 4


def process_input(input_image):
    """Validates and prepares the input image for the pipeline.

    The pixel budget check uses the fixed INTERNAL_PROCESSING_FACTOR, since that
    is the scale at which the model actually generates.
    """
    if input_image is None:
        raise gr.Error("Input image is missing!")
    try:
        input_image = ImageOps.exif_transpose(input_image)
        if input_image.mode != "RGB":
            logging.info(f"Converting input image from {input_image.mode} to RGB")
            input_image = input_image.convert("RGB")
        w, h = input_image.size
    except AttributeError:
        raise gr.Error("Invalid input image format. Please provide a valid image file.")
    except Exception as img_err:
        raise gr.Error(f"Could not process input image: {img_err}")

    w_original, h_original = w, h
    if w == 0 or h == 0:
        raise gr.Error("Input image has zero width or height.")

    target_w_internal = w * INTERNAL_PROCESSING_FACTOR
    target_h_internal = h * INTERNAL_PROCESSING_FACTOR
    target_pixels_internal = target_w_internal * target_h_internal

    was_resized = False
    input_image_to_process = input_image.copy()

    # If the intermediate output would exceed the pixel budget, shrink the input
    # so that input * INTERNAL_PROCESSING_FACTOR fits within it.
    if target_pixels_internal > MAX_PIXEL_BUDGET:
        max_input_pixels = MAX_PIXEL_BUDGET / (INTERNAL_PROCESSING_FACTOR**2)
        current_input_pixels = w * h

        if current_input_pixels > max_input_pixels:
            input_scale_factor = (max_input_pixels / current_input_pixels) ** 0.5
            input_w_resized = max(8, int(w * input_scale_factor))
            input_h_resized = max(8, int(h * input_scale_factor))
            intermediate_w = input_w_resized * INTERNAL_PROCESSING_FACTOR
            intermediate_h = input_h_resized * INTERNAL_PROCESSING_FACTOR

            logging.warning(
                f"Requested {INTERNAL_PROCESSING_FACTOR}x intermediate output "
                f"({target_w_internal}x{target_h_internal}) exceeds budget. "
                f"Resizing input from {w}x{h} to {input_w_resized}x{input_h_resized}."
            )
            gr.Info(
                f"Intermediate {INTERNAL_PROCESSING_FACTOR}x size exceeds budget. "
                f"Input resized to {input_w_resized}x{input_h_resized} "
                f"-> model generates ~{intermediate_w}x{intermediate_h}."
            )
            input_image_to_process = input_image_to_process.resize(
                (input_w_resized, input_h_resized), Image.Resampling.LANCZOS
            )
            was_resized = True

    # Round the dimensions down to multiples of 8, as required by the pipeline.
    w_proc, h_proc = input_image_to_process.size
    w_final_proc = max(8, w_proc - w_proc % 8)
    h_final_proc = max(8, h_proc - h_proc % 8)

    if (w_proc, h_proc) != (w_final_proc, h_final_proc):
        logging.info(
            f"Rounding processed input dimensions from {w_proc}x{h_proc} "
            f"to {w_final_proc}x{h_final_proc}"
        )
        input_image_to_process = input_image_to_process.resize(
            (w_final_proc, h_final_proc), Image.Resampling.LANCZOS
        )

    return input_image_to_process, w_original, h_original, was_resized
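
# Worked example of the budget math above, for illustration: with
# MAX_PIXEL_BUDGET = 1280 * 1280 = 1,638,400 px and a 4x internal factor, the
# largest input that fits is 1,638,400 / 16 = 102,400 px (~320x320). A 640x480
# input (307,200 px) is therefore scaled by sqrt(102,400 / 307,200) ≈ 0.577 to
# 369x277, rounded down to multiples of 8 -> 368x272, and the model generates a
# 1472x1088 intermediate image (1,601,536 px), which fits the budget.

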
@spaces.GPU(duration=75)
def infer(
    seed,
    randomize_seed,
    input_image,
    num_inference_steps,
    final_upscale_factor,
    controlnet_conditioning_scale,
    progress=gr.Progress(track_tqdm=True),
):
    global pipe
    if pipe is None:
        # gr.Error is an exception; it must be raised to be shown to the user.
        raise gr.Error("Pipeline not loaded. Cannot perform inference.")

    original_input_pil = input_image

    if input_image is None:
        gr.Warning("Please provide an input image.")
        return [[None, None], seed or 0, None]

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    seed = int(seed)

    final_upscale_factor = int(final_upscale_factor)
    if final_upscale_factor > INTERNAL_PROCESSING_FACTOR:
        gr.Warning(
            f"Selected upscale factor ({final_upscale_factor}x) is larger than the "
            f"internal processing factor ({INTERNAL_PROCESSING_FACTOR}x). Results might "
            f"not be optimal. Clamping the final factor to {INTERNAL_PROCESSING_FACTOR}x for this run."
        )
        final_upscale_factor = INTERNAL_PROCESSING_FACTOR

    logging.info(
        f"Starting inference with seed: {seed}, "
        f"Internal Processing Factor: {INTERNAL_PROCESSING_FACTOR}x, "
        f"Final Output Factor: {final_upscale_factor}x, "
        f"Steps: {num_inference_steps}, CNet Scale: {controlnet_conditioning_scale}"
    )

    try:
        processed_input_image, w_original, h_original, was_input_resized = process_input(
            input_image
        )
    except Exception as e:
        logging.error(f"Error processing input image: {e}", exc_info=True)
        raise gr.Error(f"Error processing input image: {e}")

    w_proc, h_proc = processed_input_image.size

    # The control image is always built at the fixed internal factor.
    control_image_w = w_proc * INTERNAL_PROCESSING_FACTOR
    control_image_h = h_proc * INTERNAL_PROCESSING_FACTOR

    # Safety clamp: rounding can still leave the control image slightly over budget.
    if control_image_w * control_image_h > MAX_PIXEL_BUDGET * 1.05:
        scale_factor = (MAX_PIXEL_BUDGET / (control_image_w * control_image_h)) ** 0.5
        control_image_w = max(8, int(control_image_w * scale_factor))
        control_image_h = max(8, int(control_image_h * scale_factor))
        control_image_w = max(8, control_image_w - control_image_w % 8)
        control_image_h = max(8, control_image_h - control_image_h % 8)
        logging.warning(
            f"Control image dimensions clamped to {control_image_w}x{control_image_h} "
            f"post-processing to fit budget."
        )
        gr.Warning(f"Control image dimensions further clamped to {control_image_w}x{control_image_h}.")

    logging.info(
        f"Resizing processed input {w_proc}x{h_proc} to control image "
        f"{control_image_w}x{control_image_h} (using {INTERNAL_PROCESSING_FACTOR}x factor)"
    )
    try:
        control_image = processed_input_image.resize(
            (control_image_w, control_image_h), Image.Resampling.LANCZOS
        )
    except ValueError as resize_err:
        logging.error(f"Error resizing processed input to control image: {resize_err}")
        raise gr.Error(f"Failed to prepare control image: {resize_err}")

    generator = torch.Generator(device=device).manual_seed(seed)

    gr.Info(
        f"Generating intermediate image at {INTERNAL_PROCESSING_FACTOR}x quality "
        f"({control_image_w}x{control_image_h})..."
    )
    logging.info(f"Running pipeline with size: {control_image_w}x{control_image_h}")
    intermediate_result_image = None
    try:
        with torch.inference_mode():
            intermediate_result_image = pipe(
                prompt="",
                control_image=control_image,
                controlnet_conditioning_scale=float(controlnet_conditioning_scale),
                num_inference_steps=int(num_inference_steps),
                guidance_scale=0.0,
                height=control_image_h,
                width=control_image_w,
                generator=generator,
            ).images[0]
        logging.info(
            f"Pipeline execution finished. Intermediate image size: "
            f"{intermediate_result_image.size if intermediate_result_image else 'None'}"
        )
    except torch.cuda.OutOfMemoryError as oom_error:
        logging.error(f"CUDA Out of Memory during pipeline execution: {oom_error}", exc_info=True)
        if device == "cuda":
            torch.cuda.empty_cache()
        raise gr.Error(
            f"Ran out of GPU memory trying to generate intermediate "
            f"{control_image_w}x{control_image_h}."
        )
    except Exception as e:
        logging.error(f"Error during pipeline execution: {e}", exc_info=True)
        raise gr.Error(f"Inference failed: {e}")

    if intermediate_result_image is None:
        logging.error("Intermediate result image is None after pipeline execution.")
        raise gr.Error("Inference produced no result image.")

    # Determine the final output size. If the input itself was downscaled to fit
    # the budget, the original dimensions are no longer a valid reference, so the
    # target is based on the processed input instead.
    if was_input_resized:
        final_target_w = w_proc * final_upscale_factor
        final_target_h = h_proc * final_upscale_factor
        logging.warning(
            f"Input was downscaled. Final size based on processed input: "
            f"{w_proc}x{h_proc} * {final_upscale_factor}x -> {final_target_w}x{final_target_h}"
        )
        gr.Info(f"Input was downscaled. Final size target approx {final_target_w}x{final_target_h}.")
    else:
        final_target_w = w_original * final_upscale_factor
        final_target_h = h_original * final_upscale_factor

    final_result_image = intermediate_result_image
    current_w, current_h = intermediate_result_image.size

    if (current_w, current_h) != (final_target_w, final_target_h):
        logging.info(
            f"Resizing intermediate image from {current_w}x{current_h} to final target "
            f"size {final_target_w}x{final_target_h} (using {final_upscale_factor}x factor)"
        )
        gr.Info(f"Resizing from intermediate {current_w}x{current_h} to final {final_target_w}x{final_target_h}...")
        try:
            if final_target_w > 0 and final_target_h > 0:
                final_result_image = intermediate_result_image.resize(
                    (final_target_w, final_target_h), Image.Resampling.LANCZOS
                )
            else:
                gr.Warning(
                    f"Invalid final target dimensions ({final_target_w}x{final_target_h}). "
                    f"Skipping final resize."
                )
        except Exception as resize_e:
            # Fall back to the intermediate result rather than failing the whole run.
            logging.error(f"Could not resize intermediate image to final size: {resize_e}")
            gr.Warning(
                f"Failed to resize to final {final_upscale_factor}x. Returning intermediate "
                f"{INTERNAL_PROCESSING_FACTOR}x result ({current_w}x{current_h})."
            )
    else:
        logging.info(f"Intermediate size {current_w}x{current_h} matches final target size. No final resize needed.")

    logging.info(f"Inference successful. Final output size: {final_result_image.size}")

    # Encode the result as a Base64 data URL so API callers receive the image inline.
    base64_string = None
    if final_result_image is not None:
        try:
            buffered = io.BytesIO()
            final_result_image.save(buffered, format="WEBP", quality=90)
            img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
            base64_string = f"data:image/webp;base64,{img_str}"
            logging.info(f"Encoded result image to Base64 string (length: {len(base64_string)} chars).")
        except Exception as enc_err:
            logging.error(f"Failed to encode result image to Base64: {enc_err}", exc_info=True)

    return [[original_input_pil, final_result_image], seed, base64_string]
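
# Client-side decoding sketch (for API consumers; not used by the app itself):
# the third output is a `data:image/webp;base64,...` string, which can be turned
# back into an image roughly like this:
#
#   import base64, io
#   from PIL import Image
#   header, payload = base64_string.split(",", 1)
#   image = Image.open(io.BytesIO(base64.b64decode(payload)))

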
with gr.Blocks(css=css, theme=gr.themes.Soft(), title="Flux Upscaler Demo") as demo:
    gr.Markdown(
        f"""
        # ⚡ Flux.1-dev Upscaler ControlNet ⚡
        Upscale images using the [Flux.1-dev Upscaler ControlNet](https://huggingface.co/jasperai/Flux.1-dev-Controlnet-Upscaler) model based on [FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev).
        Currently running on **{power_device}**. Hardware provided by Hugging Face 🤗.

        **How it works:** This demo uses an internal processing scale of **{INTERNAL_PROCESSING_FACTOR}x** for potentially higher detail generation,
        then resizes the result to your selected **Final Upscale Factor**. This aims for {INTERNAL_PROCESSING_FACTOR}x quality at your desired output resolution.

        *Note*: The intermediate processing resolution is limited to approximately **{MAX_PIXEL_BUDGET / 1_000_000:.1f} megapixels** ({int(MAX_PIXEL_BUDGET**0.5)}x{int(MAX_PIXEL_BUDGET**0.5)}) due to resource constraints.
        Diffusion time is determined mainly by this intermediate size, not by the final output size.
        """
    )

    with gr.Row():
        with gr.Column(scale=2):
            input_im = gr.Image(
                label="Input Image",
                type="pil",
                height=350,
                sources=["upload", "clipboard"],
            )
        with gr.Column(scale=1):
            upscale_factor_slider = gr.Slider(
                label="Final Upscale Factor",
                info=f"Output size relative to input. Internal processing uses {INTERNAL_PROCESSING_FACTOR}x quality.",
                minimum=1,
                maximum=INTERNAL_PROCESSING_FACTOR,
                step=1,
                value=2,
            )
            num_inference_steps = gr.Slider(
                label="Inference Steps", minimum=4, maximum=50, step=1, value=15
            )
            controlnet_conditioning_scale = gr.Slider(
                label="ControlNet Conditioning Scale",
                info="Strength of ControlNet guidance",
                minimum=0.0,
                maximum=1.5,
                step=0.05,
                value=0.6,
            )
            with gr.Row():
                seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42)
                randomize_seed = gr.Checkbox(label="Random", value=True, scale=0, min_width=80)
            run_button = gr.Button("⚡ Upscale Image", variant="primary", scale=1)

    with gr.Row():
        result_slider = ImageSlider(
            label="Input / Output Comparison",
            type="pil",
            interactive=False,
            show_label=True,
            position=0.5,
        )

    output_seed = gr.Textbox(label="Seed Used", interactive=False, visible=True, scale=1)
    # Hidden field that exposes the Base64-encoded result to API callers.
    api_base64_output = gr.Textbox(label="API Base64 Output", interactive=False, visible=False)

example_dir = "examples" |
|
example_files = ["image_2.jpg", "image_4.jpg", "low_res_face.png", "low_res_landscape.png"] |
|
example_paths = [os.path.join(example_dir, f) for f in example_files if os.path.exists(os.path.join(example_dir, f))] |
|
|
|
if example_paths: |
|
gr.Examples( |
|
|
|
examples=[ [path, 2, 15, 0.6, random.randint(0,MAX_SEED), True] for path in example_paths ], |
|
|
|
inputs=[ input_im, upscale_factor_slider, num_inference_steps, controlnet_conditioning_scale, seed, randomize_seed, ], |
|
outputs=[result_slider, output_seed], |
|
fn=infer, |
|
cache_examples="lazy", |
|
label="Example Images (Click to Run with 2x Output)", |
|
run_on_click=True |
|
) |
|
else: |
|
gr.Markdown(f"*No example images found in '{example_dir}' directory.*") |
|
|
|
gr.Markdown("---") |
|
gr.Markdown("**Disclaimer:** Demo for illustrative purposes. Users are responsible for generated content.") |
|
|
|
|
|
run_button.click( |
|
fn=infer, |
|
inputs=[ |
|
seed, |
|
randomize_seed, |
|
input_im, |
|
num_inference_steps, |
|
upscale_factor_slider, |
|
controlnet_conditioning_scale, |
|
], |
|
outputs=[result_slider, output_seed, api_base64_output], |
|
api_name="upscale" |
|
) |
|
|
|
|
|
demo.queue(max_size=10).launch(share=False, show_api=True) |
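
# Programmatic usage sketch (an assumption, not part of this app): with the Space
# running, the `upscale` endpoint can be called via `gradio_client`. The Space URL
# and file name below are placeholders.
#
#   from gradio_client import Client, handle_file
#
#   client = Client("<your-username/your-space>")
#   comparison, used_seed, b64_image = client.predict(
#       42,                          # seed
#       True,                        # randomize_seed
#       handle_file("low_res.png"),  # input_image
#       15,                          # num_inference_steps
#       2,                           # final_upscale_factor
#       0.6,                         # controlnet_conditioning_scale
#       api_name="/upscale",
#   )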