Supports images with transparent background.

#3
by soiz1 - opened
Files changed (1) hide show
  1. app.py +44 -7
app.py CHANGED
@@ -48,16 +48,52 @@ def read_content(file_path: str) -> str:
48
 
49
  return content
50
 
51
- def predict(img, prompt="", seed=0):
52
- img = img.convert("RGB")
53
  img = resize_with_padding(img, (512, 512))
54
- mask = remover.process(img, type='map')
55
- mask = ImageOps.invert(mask)
 
 
 
 
 
 
 
 
 
 
 
 
56
  with torch.autocast("cuda"):
57
  generator = torch.Generator(device='cuda').manual_seed(seed)
58
- output_controlnet = pipe(generator=generator, prompt=prompt, image=img, mask_image=mask, control_image=mask, num_images_per_prompt=1, num_inference_steps=20, guess_mode=False, controlnet_conditioning_scale=1.0, guidance_scale=7.5).images[0]
 
 
 
 
 
 
 
 
 
 
 
 
59
  generator = torch.Generator(device='cuda').manual_seed(seed)
60
- output_sd2 = pipe(generator=generator, prompt=prompt, image=img, mask_image=mask, control_image=mask, num_images_per_prompt=1, num_inference_steps=20, guess_mode=False, controlnet_conditioning_scale=0.0, guidance_scale=7.5).images[0]
 
 
 
 
 
 
 
 
 
 
 
 
61
  torch.cuda.empty_cache()
62
  return output_controlnet, output_sd2, mask
63
 
@@ -108,6 +144,7 @@ with image_blocks as demo:
108
  with gr.Column(variant='compact', ):
109
  image = gr.Image(value=bird_image, sources=['upload'], elem_id="image_upload", type="pil", label="Upload an image", width=512, height=512)
110
  with gr.Row(variant='compact', elem_id="prompt-container", equal_height=True):
 
111
  prompt = gr.Textbox(label='prompt', placeholder = 'What you want in the background?', show_label=True, elem_id="input-text")
112
  seed = gr.Number(label="seed", value=13)
113
  btn = gr.Button("Generate Background!")
@@ -119,7 +156,7 @@ with image_blocks as demo:
119
  mask_out = gr.Image(value=bird_mask, label="Background Mask", elem_id="output-mask", width=512, height=512)
120
  with gr.Column(variant='compact', ):
121
  sd2_out = gr.Image(value=bird_sd2, label="SD2 Output", elem_id="output-sd2", width=512, height=512)
122
- btn.click(fn=predict, inputs=[image, prompt, seed], outputs=[controlnet_out, sd2_out, mask_out ])
123
 
124
 
125
 
 
48
 
49
  return content
50
 
51
+ def predict(img, prompt="", seed=0, use_removal=True):  # Build a background mask for img, run pipe twice with the same seed (ControlNet scale 1.0, then 0.0), and return both outputs plus the mask.
52
+ img = img.convert("RGBA")  # keep an alpha channel so it can be read as the mask when use_removal is False
53
  img = resize_with_padding(img, (512, 512))
54
+
55
+ if use_removal:
56
+ # Automatically extract the background with remover
57
+ mask = remover.process(img.convert("RGB"), type='map')
58
+ mask = ImageOps.invert(mask)
59
+ else:
60
+ # Use the alpha channel of a PNG whose background was already removed as the mask
61
+ if "A" in img.getbands(): # whether the image is RGBA. NOTE(review): img was converted to RGBA above, so this looks always true unless resize_with_padding changes the mode - confirm
62
+ mask = img.getchannel("A")
63
+ mask = ImageOps.invert(mask) # foreground black, background white
64
+ else:
65
+ # No alpha channel (e.g. JPEG) -> do not edit the background (all black)
66
+ mask = Image.new("L", img.size, 0)
67
+
68
  with torch.autocast("cuda"):
69
  generator = torch.Generator(device='cuda').manual_seed(seed)
70
+ output_controlnet = pipe(
71
+ generator=generator,
72
+ prompt=prompt,
73
+ image=img.convert("RGB"),
74
+ mask_image=mask,
75
+ control_image=mask,
76
+ num_images_per_prompt=1,
77
+ num_inference_steps=20,
78
+ guess_mode=False,
79
+ controlnet_conditioning_scale=1.0,
80
+ guidance_scale=7.5
81
+ ).images[0]
82
+
83
  generator = torch.Generator(device='cuda').manual_seed(seed)  # fresh generator with the same seed for the second run
84
+ output_sd2 = pipe(
85
+ generator=generator,
86
+ prompt=prompt,
87
+ image=img.convert("RGB"),
88
+ mask_image=mask,
89
+ control_image=mask,
90
+ num_images_per_prompt=1,
91
+ num_inference_steps=20,
92
+ guess_mode=False,
93
+ controlnet_conditioning_scale=0.0,  # same call as above, but with the ControlNet conditioning scale set to zero
94
+ guidance_scale=7.5
95
+ ).images[0]
96
+
97
  torch.cuda.empty_cache()
98
  return output_controlnet, output_sd2, mask
99
 
 
144
  with gr.Column(variant='compact', ):
145
  image = gr.Image(value=bird_image, sources=['upload'], elem_id="image_upload", type="pil", label="Upload an image", width=512, height=512)
146
  with gr.Row(variant='compact', elem_id="prompt-container", equal_height=True):
147
+ use_removal = gr.Checkbox(label="Use BG Remover", value=True)
148
  prompt = gr.Textbox(label='prompt', placeholder = 'What you want in the background?', show_label=True, elem_id="input-text")
149
  seed = gr.Number(label="seed", value=13)
150
  btn = gr.Button("Generate Background!")
 
156
  mask_out = gr.Image(value=bird_mask, label="Background Mask", elem_id="output-mask", width=512, height=512)
157
  with gr.Column(variant='compact', ):
158
  sd2_out = gr.Image(value=bird_sd2, label="SD2 Output", elem_id="output-sd2", width=512, height=512)
159
+ btn.click(fn=predict, inputs=[image, prompt, seed, use_removal], outputs=[controlnet_out, sd2_out, mask_out ])
160
 
161
 
162