GreenGoat committed
Commit 9397567 · verified · 1 Parent(s): 1fbd787

Update app.py

Files changed (1)
  1. app.py +86 -13
app.py CHANGED
@@ -8,7 +8,11 @@ import safetensors.torch as sf
 from PIL import Image
 from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline
 from diffusers import AutoencoderKL, UNet2DConditionModel, DDIMScheduler, EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler
-from diffusers.models.attention_processor import AttnProcessor2_0
+try:
+    from diffusers.models.attention_processor import AttnProcessor2_0
+except ImportError:
+    # Fallback for older diffusers versions
+    AttnProcessor2_0 = None
 from transformers import CLIPTextModel, CLIPTokenizer
 from enum import Enum
 from torch.hub import download_url_to_file
@@ -20,6 +24,7 @@ from huggingface_hub import PyTorchModelHubMixin
 try:
     from transformers import pipeline
     rmbg_pipeline = pipeline("image-segmentation", model="briaai/RMBG-1.4", trust_remote_code=True)
+    USE_RMBG_PIPELINE = True
 except Exception as e:
     print(f"Failed to load RMBG pipeline: {e}")
     USE_RMBG_PIPELINE = False
@@ -46,10 +51,10 @@ print(f"Using device: {device}")
 print("Loading models...")
 
 # Initialize models
-tokenizer = CLIPTokenizer.from_pretrained(sd15_name, subfolder="tokenizer", trust_remote_code=True)
-text_encoder = CLIPTextModel.from_pretrained(sd15_name, subfolder="text_encoder", trust_remote_code=True)
-vae = AutoencoderKL.from_pretrained(sd15_name, subfolder="vae", trust_remote_code=True)
-unet = UNet2DConditionModel.from_pretrained(sd15_name, subfolder="unet", trust_remote_code=True)
+tokenizer = CLIPTokenizer.from_pretrained(sd15_name, subfolder="tokenizer")
+text_encoder = CLIPTextModel.from_pretrained(sd15_name, subfolder="text_encoder")
+vae = AutoencoderKL.from_pretrained(sd15_name, subfolder="vae")
+unet = UNet2DConditionModel.from_pretrained(sd15_name, subfolder="unet")
 
 # Modify UNet for IC-Light
 with torch.no_grad():
@@ -88,8 +93,14 @@ vae = vae.to(device=device, dtype=torch.bfloat16)
 unet = unet.to(device=device, dtype=torch.float16)
 
 # Set attention processors
-unet.set_attn_processor(AttnProcessor2_0())
-vae.set_attn_processor(AttnProcessor2_0())
+if AttnProcessor2_0 is not None:
+    try:
+        unet.set_attn_processor(AttnProcessor2_0())
+        vae.set_attn_processor(AttnProcessor2_0())
+    except Exception as e:
+        print(f"Failed to set attention processors: {e}")
+else:
+    print("AttnProcessor2_0 not available, using default processors")
 
 # Scheduler
 scheduler = DPMSolverMultistepScheduler(
@@ -340,10 +351,66 @@ def process(input_fg, input_bg, prompt, image_width, image_height, num_samples,
 
 @torch.inference_mode()
 def process_relight(input_fg, input_bg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, bg_source):
-    input_fg, matting = run_rmbg(input_fg)
-    results, extra_images = process(input_fg, input_bg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, bg_source)
-    results = [(x * 255.0).clip(0, 255).astype(np.uint8) for x in results]
-    return results + extra_images
+    try:
+        # Input validation
+        if input_fg is None:
+            raise ValueError("Please upload a foreground image")
+        if input_bg is None and bg_source == "Use Background Image":
+            raise ValueError("Please upload a background image or choose a lighting direction")
+        if not prompt.strip():
+            raise ValueError("Please enter a prompt")
+
+        print(f"Processing with device: {device}")
+        print(f"Input shapes - FG: {input_fg.shape}, BG: {input_bg.shape if input_bg is not None else 'None'}")
+
+        # Optimize for Hugging Face free GPU (limited memory)
+        if device.type == 'cuda':
+            # Limit image size for free GPU tier
+            max_size = 768  # Increased for GPU but still conservative
+            if image_width > max_size or image_height > max_size:
+                scale = min(max_size / image_width, max_size / image_height)
+                image_width = int(image_width * scale // 64) * 64  # Keep multiple of 64
+                image_height = int(image_height * scale // 64) * 64
+                print(f"Reduced image size for GPU memory: {image_width}x{image_height}")
+
+            # Disable highres for free tier to save memory
+            if highres_scale > 1.0:
+                highres_scale = 1.0
+                print("Disabled highres scaling to save GPU memory")
+
+        elif device.type == 'cpu':
+            # Limit image size for CPU processing
+            max_size = 512
+            if image_width > max_size or image_height > max_size:
+                image_width = min(image_width, max_size)
+                image_height = min(image_height, max_size)
+                print(f"Reduced image size for CPU: {image_width}x{image_height}")
+
+            # Limit number of samples for CPU
+            if num_samples > 1:
+                num_samples = 1
+                print("Reduced num_samples to 1 for CPU processing")
+
+        print("Running background removal...")
+        input_fg, matting = run_rmbg(input_fg)
+
+        print("Starting main processing...")
+        results, extra_images = process(input_fg, input_bg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, bg_source)
+
+        print("Converting results...")
+        results = [(x * 255.0).clip(0, 255).astype(np.uint8) for x in results]
+
+        print("Processing completed successfully!")
+        return results + extra_images
+
+    except Exception as e:
+        print(f"Error in process_relight: {str(e)}")
+        import traceback
+        traceback.print_exc()
+        # Return error image
+        error_img = np.zeros((512, 512, 3), dtype=np.uint8)
+        error_img[:, :] = [255, 0, 0]  # Red error image
+        return [error_img]
 
 # Quick prompts for easy testing
 quick_prompts = [
@@ -405,10 +472,16 @@ def create_demo():
 
     # Event handlers
    inputs = [input_fg, input_bg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, bg_source]
-    relight_button.click(fn=process_relight, inputs=inputs, outputs=[result_gallery])
+    relight_button.click(
+        fn=process_relight,
+        inputs=inputs,
+        outputs=[result_gallery],
+        show_progress=True,
+        queue=True
+    )
     example_prompts.click(lambda x: x[0], inputs=example_prompts, outputs=prompt, show_progress=False, queue=False)
 
-    # Examples
+    # Examples - temporarily disabled due to missing image files
     # gr.Examples(
     #     examples=[
     #         ["examples/person1.jpg", "examples/bg1.jpg", "beautiful woman, cinematic lighting", "Use Background Image"],