Update app.py
app.py CHANGED
@@ -8,7 +8,11 @@ import safetensors.torch as sf
 from PIL import Image
 from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline
 from diffusers import AutoencoderKL, UNet2DConditionModel, DDIMScheduler, EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler
-
+try:
+    from diffusers.models.attention_processor import AttnProcessor2_0
+except ImportError:
+    # Fallback for older diffusers versions
+    AttnProcessor2_0 = None
 from transformers import CLIPTextModel, CLIPTokenizer
 from enum import Enum
 from torch.hub import download_url_to_file
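The guarded import added here matters because AttnProcessor2_0 depends on PyTorch 2.0's fused scaled_dot_product_attention; on older torch or diffusers stacks it is unavailable, and the app falls back to the default processors. A minimal standalone sketch of the same capability check (the names HAS_SDPA and use_fused_attention are illustrative, not part of the app):

import torch
import torch.nn.functional as F

try:
    from diffusers.models.attention_processor import AttnProcessor2_0
except ImportError:
    AttnProcessor2_0 = None

# AttnProcessor2_0 only helps when PyTorch exposes fused SDPA (PyTorch >= 2.0).
HAS_SDPA = hasattr(F, "scaled_dot_product_attention")
use_fused_attention = HAS_SDPA and AttnProcessor2_0 is not None
print(f"torch {torch.__version__}: fused attention available = {use_fused_attention}")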
@@ -20,6 +24,7 @@ from huggingface_hub import PyTorchModelHubMixin
 try:
     from transformers import pipeline
     rmbg_pipeline = pipeline("image-segmentation", model="briaai/RMBG-1.4", trust_remote_code=True)
+    USE_RMBG_PIPELINE = True
 except Exception as e:
     print(f"Failed to load RMBG pipeline: {e}")
     USE_RMBG_PIPELINE = False
@@ -46,10 +51,10 @@ print(f"Using device: {device}")
 print("Loading models...")
 
 # Initialize models
-tokenizer = CLIPTokenizer.from_pretrained(sd15_name, subfolder="tokenizer"
-text_encoder = CLIPTextModel.from_pretrained(sd15_name, subfolder="text_encoder"
-vae = AutoencoderKL.from_pretrained(sd15_name, subfolder="vae"
-unet = UNet2DConditionModel.from_pretrained(sd15_name, subfolder="unet"
+tokenizer = CLIPTokenizer.from_pretrained(sd15_name, subfolder="tokenizer")
+text_encoder = CLIPTextModel.from_pretrained(sd15_name, subfolder="text_encoder")
+vae = AutoencoderKL.from_pretrained(sd15_name, subfolder="vae")
+unet = UNet2DConditionModel.from_pretrained(sd15_name, subfolder="unet")
 
 # Modify UNet for IC-Light
 with torch.no_grad():
@@ -88,8 +93,14 @@ vae = vae.to(device=device, dtype=torch.bfloat16)
 unet = unet.to(device=device, dtype=torch.float16)
 
 # Set attention processors
-
-
+if AttnProcessor2_0 is not None:
+    try:
+        unet.set_attn_processor(AttnProcessor2_0())
+        vae.set_attn_processor(AttnProcessor2_0())
+    except Exception as e:
+        print(f"Failed to set attention processors: {e}")
+else:
+    print("AttnProcessor2_0 not available, using default processors")
 
 # Scheduler
 scheduler = DPMSolverMultistepScheduler(
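When the fallback leaves AttnProcessor2_0 as None, the UNet and VAE simply keep the processors they were constructed with; diffusers also lets you restore the defaults explicitly. A small sketch of that pattern (assuming the installed diffusers version provides set_default_attn_processor on both models; the helper name is illustrative):

def apply_attention_processors(unet, vae, processor_cls=AttnProcessor2_0):
    # Prefer the fused PyTorch 2.0 processor when it imported successfully,
    # otherwise reset both models to the library's default attention processors.
    if processor_cls is not None:
        unet.set_attn_processor(processor_cls())
        vae.set_attn_processor(processor_cls())
    else:
        unet.set_default_attn_processor()
        vae.set_default_attn_processor()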
@@ -340,10 +351,66 @@ def process(input_fg, input_bg, prompt, image_width, image_height, num_samples,
 
 @torch.inference_mode()
 def process_relight(input_fg, input_bg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, bg_source):
-
-
-
-
+    try:
+        # Input validation
+        if input_fg is None:
+            raise ValueError("Please upload a foreground image")
+        if input_bg is None and bg_source == "Use Background Image":
+            raise ValueError("Please upload a background image or choose a lighting direction")
+        if not prompt.strip():
+            raise ValueError("Please enter a prompt")
+
+        print(f"Processing with device: {device}")
+        print(f"Input shapes - FG: {input_fg.shape}, BG: {input_bg.shape if input_bg is not None else 'None'}")
+
+        # Optimize for Hugging Face free GPU (limited memory)
+        if device.type == 'cuda':
+            # Limit image size for free GPU tier
+            max_size = 768  # Increased for GPU but still conservative
+            if image_width > max_size or image_height > max_size:
+                scale = min(max_size / image_width, max_size / image_height)
+                image_width = int(image_width * scale // 64) * 64  # Keep multiple of 64
+                image_height = int(image_height * scale // 64) * 64
+                print(f"Reduced image size for GPU memory: {image_width}x{image_height}")
+
+            # Disable highres for free tier to save memory
+            if highres_scale > 1.0:
+                highres_scale = 1.0
+                print("Disabled highres scaling to save GPU memory")
+
+        elif device.type == 'cpu':
+            # Limit image size for CPU processing
+            max_size = 512
+            if image_width > max_size or image_height > max_size:
+                image_width = min(image_width, max_size)
+                image_height = min(image_height, max_size)
+                print(f"Reduced image size for CPU: {image_width}x{image_height}")
+
+            # Limit number of samples for CPU
+            if num_samples > 1:
+                num_samples = 1
+                print("Reduced num_samples to 1 for CPU processing")
+
+        print("Running background removal...")
+        input_fg, matting = run_rmbg(input_fg)
+
+        print("Starting main processing...")
+        results, extra_images = process(input_fg, input_bg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, bg_source)
+
+        print("Converting results...")
+        results = [(x * 255.0).clip(0, 255).astype(np.uint8) for x in results]
+
+        print("Processing completed successfully!")
+        return results + extra_images
+
+    except Exception as e:
+        print(f"Error in process_relight: {str(e)}")
+        import traceback
+        traceback.print_exc()
+        # Return error image
+        error_img = np.zeros((512, 512, 3), dtype=np.uint8)
+        error_img[:, :] = [255, 0, 0]  # Red error image
+        return [error_img]
 
 # Quick prompts for easy testing
 quick_prompts = [
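The new GPU branch keeps requests within the free-tier memory budget by rescaling the longer side to max_size and snapping both dimensions down to multiples of 64, as the inline comment notes. A standalone sketch of that arithmetic with illustrative numbers (the helper name clamp_to_grid is not part of the app):

def clamp_to_grid(width, height, max_size=768):
    # Mirrors the resizing in process_relight: scale the longer side down to
    # max_size, then snap both dimensions to the next-lower multiple of 64.
    if width > max_size or height > max_size:
        scale = min(max_size / width, max_size / height)
        width = int(width * scale // 64) * 64
        height = int(height * scale // 64) * 64
    return width, height

# Example: a 1024x768 request on the GPU path (max_size = 768):
# scale = min(768/1024, 768/768) = 0.75, giving 768x576 - both multiples of 64.
print(clamp_to_grid(1024, 768))  # (768, 576)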
@@ -405,10 +472,16 @@ def create_demo():
 
         # Event handlers
         inputs = [input_fg, input_bg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, bg_source]
-        relight_button.click(
+        relight_button.click(
+            fn=process_relight,
+            inputs=inputs,
+            outputs=[result_gallery],
+            show_progress=True,
+            queue=True
+        )
         example_prompts.click(lambda x: x[0], inputs=example_prompts, outputs=prompt, show_progress=False, queue=False)
 
-        # Examples
+        # Examples - temporarily disabled due to missing image files
         # gr.Examples(
         #     examples=[
         #         ["examples/person1.jpg", "examples/bg1.jpg", "beautiful woman, cinematic lighting", "Use Background Image"],
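Because the relight_button handler is registered with queue=True, the Blocks app typically needs its queue enabled before launch so the long-running relight job is routed through it. A minimal launch sketch, assuming create_demo() returns the gr.Blocks instance built above (the queue size shown is illustrative):

if __name__ == "__main__":
    demo = create_demo()      # assumed to return the gr.Blocks built in create_demo()
    demo.queue(max_size=20)   # enable the queue used by the click handler
    demo.launch()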