Files changed (1) hide show
  1. app.py +37 -7
app.py CHANGED
@@ -48,19 +48,48 @@ def read_content(file_path: str) -> str:
48
 
49
  return content
50
 
51
- def predict(img, prompt="", seed=0):
52
  img = img.convert("RGB")
53
  img = resize_with_padding(img, (512, 512))
54
- mask = remover.process(img, type='map')
55
- mask = ImageOps.invert(mask)
 
 
 
 
 
56
  with torch.autocast("cuda"):
57
  generator = torch.Generator(device='cuda').manual_seed(seed)
58
- output_controlnet = pipe(generator=generator, prompt=prompt, image=img, mask_image=mask, control_image=mask, num_images_per_prompt=1, num_inference_steps=20, guess_mode=False, controlnet_conditioning_scale=1.0, guidance_scale=7.5).images[0]
 
 
 
 
 
 
 
 
 
 
 
 
59
  generator = torch.Generator(device='cuda').manual_seed(seed)
60
- output_sd2 = pipe(generator=generator, prompt=prompt, image=img, mask_image=mask, control_image=mask, num_images_per_prompt=1, num_inference_steps=20, guess_mode=False, controlnet_conditioning_scale=0.0, guidance_scale=7.5).images[0]
 
 
 
 
 
 
 
 
 
 
 
 
61
  torch.cuda.empty_cache()
62
  return output_controlnet, output_sd2, mask
63
-
64
  css = '''
65
  .container {max-width: 1150px;margin: auto;padding-top: 1.5rem}
66
  #image_upload{min-height:400px}
@@ -108,6 +137,7 @@ with image_blocks as demo:
108
  with gr.Column(variant='compact', ):
109
  image = gr.Image(value=bird_image, sources=['upload'], elem_id="image_upload", type="pil", label="Upload an image", width=512, height=512)
110
  with gr.Row(variant='compact', elem_id="prompt-container", equal_height=True):
 
111
  prompt = gr.Textbox(label='prompt', placeholder = 'What you want in the background?', show_label=True, elem_id="input-text")
112
  seed = gr.Number(label="seed", value=13)
113
  btn = gr.Button("Generate Background!")
@@ -119,7 +149,7 @@ with image_blocks as demo:
119
  mask_out = gr.Image(value=bird_mask, label="Background Mask", elem_id="output-mask", width=512, height=512)
120
  with gr.Column(variant='compact', ):
121
  sd2_out = gr.Image(value=bird_sd2, label="SD2 Output", elem_id="output-sd2", width=512, height=512)
122
- btn.click(fn=predict, inputs=[image, prompt, seed], outputs=[controlnet_out, sd2_out, mask_out ])
123
 
124
 
125
 
 
48
 
49
  return content
50
 
51
+ def predict(img, prompt="", seed=0, use_removal=True):
52
  img = img.convert("RGB")
53
  img = resize_with_padding(img, (512, 512))
54
+
55
+ if use_removal:
56
+ mask = remover.process(img, type='map')
57
+ mask = ImageOps.invert(mask)
58
+ else:
59
+ mask = Image.new("L", img.size, 0)
60
+
61
  with torch.autocast("cuda"):
62
  generator = torch.Generator(device='cuda').manual_seed(seed)
63
+ output_controlnet = pipe(
64
+ generator=generator,
65
+ prompt=prompt,
66
+ image=img,
67
+ mask_image=mask,
68
+ control_image=mask,
69
+ num_images_per_prompt=1,
70
+ num_inference_steps=20,
71
+ guess_mode=False,
72
+ controlnet_conditioning_scale=1.0,
73
+ guidance_scale=7.5
74
+ ).images[0]
75
+
76
  generator = torch.Generator(device='cuda').manual_seed(seed)
77
+ output_sd2 = pipe(
78
+ generator=generator,
79
+ prompt=prompt,
80
+ image=img,
81
+ mask_image=mask,
82
+ control_image=mask,
83
+ num_images_per_prompt=1,
84
+ num_inference_steps=20,
85
+ guess_mode=False,
86
+ controlnet_conditioning_scale=0.0,
87
+ guidance_scale=7.5
88
+ ).images[0]
89
+
90
  torch.cuda.empty_cache()
91
  return output_controlnet, output_sd2, mask
92
+
93
  css = '''
94
  .container {max-width: 1150px;margin: auto;padding-top: 1.5rem}
95
  #image_upload{min-height:400px}
 
137
  with gr.Column(variant='compact', ):
138
  image = gr.Image(value=bird_image, sources=['upload'], elem_id="image_upload", type="pil", label="Upload an image", width=512, height=512)
139
  with gr.Row(variant='compact', elem_id="prompt-container", equal_height=True):
140
+ use_removal = gr.Checkbox(label="Use BG Remover", value=True)
141
  prompt = gr.Textbox(label='prompt', placeholder = 'What you want in the background?', show_label=True, elem_id="input-text")
142
  seed = gr.Number(label="seed", value=13)
143
  btn = gr.Button("Generate Background!")
 
149
  mask_out = gr.Image(value=bird_mask, label="Background Mask", elem_id="output-mask", width=512, height=512)
150
  with gr.Column(variant='compact', ):
151
  sd2_out = gr.Image(value=bird_sd2, label="SD2 Output", elem_id="output-sd2", width=512, height=512)
152
+ btn.click(fn=predict, inputs=[image, prompt, seed, use_removal], outputs=[controlnet_out, sd2_out, mask_out ])
153
 
154
 
155