Ash2505 committed
Commit d110ed4 · verified · 1 Parent(s): df76deb

Update app.py

Files changed (1): app.py +194 -4
app.py CHANGED
@@ -1,7 +1,197 @@
  import gradio as gr

- def greet(name):
-     return "Hello " + name + "!!"

- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
- demo.launch()
  import gradio as gr
+ from PIL import Image, ImageFilter
+ import matplotlib.pyplot as plt
+ import torch
+ import cv2
+ import numpy as np
+ from torchvision import transforms
+ from transformers import AutoModelForImageSegmentation, DepthProImageProcessorFast, DepthProForDepthEstimation
+ import requests

+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

+ # Load BiRefNet once at startup; half precision assumes the CUDA device above.
+ birefnet = AutoModelForImageSegmentation.from_pretrained('ZhengPeng7/BiRefNet', trust_remote_code=True)
+ torch.set_float32_matmul_precision('high')
+ birefnet.to(device)
+ birefnet.eval()
+ birefnet.half()
+
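+ # Pipeline overview: BiRefNet segments the subject for a uniform background
+ # blur ("bokeh"), while a DepthPro depth map drives a three-band blur
+ # ("lens blur"); extract_object below computes and returns both results.
+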
+ def extract_object(image, t1, t2):
+     # Data settings
+     image_size = (1024, 1024)
+     transform_image = transforms.Compose([
+         transforms.Resize(image_size),
+         transforms.ToTensor(),
+         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+     ])
+
+     # Work on a fixed 512x512 copy so the mask, blurs, and depth map all align
+     imageResized = image.resize((512, 512))
+     image1 = imageResized.copy()
+     input_images = transform_image(image1).unsqueeze(0).to(device).half()
+
+     # Prediction
+     with torch.no_grad():
+         preds = birefnet(input_images)[-1].sigmoid().cpu()
+     pred = preds[0].squeeze()
+     pred_pil = transforms.ToPILImage()(pred)
+     mask = pred_pil.resize(image1.size)
+     image1.putalpha(mask)
+
+     # Uniform "bokeh": keep the masked subject sharp, blur everything else
+     blurredBg = cv2.GaussianBlur(np.array(imageResized), (0, 0), sigmaX=15, sigmaY=15)
+
+     maskArr = np.array(mask.convert("L"))
+     _, maskBinary = cv2.threshold(maskArr, 127, 255, cv2.THRESH_BINARY)
+     img = cv2.cvtColor(np.array(imageResized), cv2.COLOR_RGB2BGR)
+
+     maskInv = cv2.bitwise_not(maskBinary)
+     maskInv3 = cv2.cvtColor(maskInv, cv2.COLOR_GRAY2BGR)
+
+     foreground = cv2.bitwise_and(img, cv2.bitwise_not(maskInv3))
+     background = cv2.bitwise_and(blurredBg, maskInv3)
+     finalImg = cv2.add(cv2.cvtColor(foreground, cv2.COLOR_BGR2RGB), background)
+
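+     # Depth-based lens blur: estimate per-pixel depth with DepthPro, then
+     # blend sharp, lightly blurred, and heavily blurred copies by depth band.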
+     # Depth estimation with apple/DepthPro-hf (loaded on each call for simplicity)
+     imageProcessor = DepthProImageProcessorFast.from_pretrained("apple/DepthPro-hf")
+     model = DepthProForDepthEstimation.from_pretrained("apple/DepthPro-hf").to(device)
+
+     inputs = imageProcessor(images=imageResized, return_tensors="pt").to(device)
+
+     with torch.no_grad():
+         outputs = model(**inputs)
+
+     post_processed_output = imageProcessor.post_process_depth_estimation(
+         outputs, target_sizes=[(imageResized.height, imageResized.width)],
+     )
+
+     field_of_view = post_processed_output[0]["field_of_view"]
+     focal_length = post_processed_output[0]["focal_length"]
+     depth = post_processed_output[0]["predicted_depth"]
+     depth = (depth - depth.min()) / (depth.max() - depth.min())
+     depth = depth * 255.
+     depth = depth.detach().cpu().numpy()
+     depthImg = Image.fromarray(depth.astype("uint8"))
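+
+     # depth is now min-max normalized to [0, 255]; DepthPro predicts metric
+     # depth, so smaller values correspond to pixels nearer the camera.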
+
+     # threshold1 = 255 / 3      # ~85
+     # threshold2 = 2 * 255 / 3  # ~170
+
+     # Sliders run 0-99, so divide by 100 before scaling to the 0-255 depth
+     # range; the defaults (33, 66) reproduce the ~85 and ~170 cutoffs above.
+     threshold1 = (t1 / 100) * 255
+     threshold2 = (t2 / 100) * 255
+
+     # Precompute blurred versions for each region
+     img_foreground = img.copy()  # No blur for foreground
+     img_middleground = cv2.GaussianBlur(img, (0, 0), sigmaX=7, sigmaY=7)
+     img_background = cv2.GaussianBlur(img, (0, 0), sigmaX=15, sigmaY=15)
+
+     # Create masks for each region (as float arrays for proper blending)
+     mask_fg = (depth < threshold1).astype(np.float32)
+     mask_mg = ((depth >= threshold1) & (depth < threshold2)).astype(np.float32)
+     mask_bg = (depth >= threshold2).astype(np.float32)
+
+     # Expand masks to 3 channels (H, W, 3)
+     mask_fg = np.stack([mask_fg] * 3, axis=-1)
+     mask_mg = np.stack([mask_mg] * 3, axis=-1)
+     mask_bg = np.stack([mask_bg] * 3, axis=-1)
+
+     # Combine the images using the masks in a vectorized manner.
+     final_img = (img_foreground * mask_fg +
+                  img_middleground * mask_mg +
+                  img_background * mask_bg).astype(np.uint8)
+
+     # Convert the result back to RGB for display.
+     final_img_rgb = cv2.cvtColor(final_img, cv2.COLOR_BGR2RGB)
+
+     # Return the segmentation-based bokeh image and the depth-based lens-blur image
+     return Image.fromarray(finalImg), Image.fromarray(final_img_rgb)
+
+ # Visualization (notebook-style reference, kept commented out)
+ # plt.axis("off")
+ # subplots for 3 images: original, segmented, mask
+
+ # plt.figure(figsize=(15, 5))
+
+ # image = Image.open('/content/drive/MyDrive/eee515-hw3/hw3-q24.jpg')
+ # # resize the image to 512x512
+ # imageResized = image.resize((512, 512))
+
+ # result = extract_object(birefnet, imageResized)
+ # plt.subplot(1, 3, 1)
+ # plt.title("Original Resized Image")
+ # plt.imshow(imageResized)
+
+ # plt.subplot(1, 3, 2)
+ # plt.title("Segmented Image")
+ # plt.imshow(result[0])
+
+ # plt.subplot(1, 3, 3)
+ # plt.title("Mask")
+ # plt.imshow(result[1], cmap="gray")
+ # plt.show()
+
+ # Create a Gradio interface
+
+
+ def build_interface():
+     """Build UI for gradio app"""
+     title = "Bokeh and Lens Blur"
+     with gr.Blocks(theme=gr.themes.Soft(), title=title, fill_width=True) as interface:
+         with gr.Row():
+             # with gr.Column(scale=3):
+             #     with gr.Group():
+             #         input_text_box = gr.Textbox(
+             #             value=None,
+             #             label="Prompt",
+             #             lines=2,
+             #         )
+             #         # gr.Markdown("### Set the values for Middleground and Background")
+             #         # fg = gr.Slider(minimum=0, maximum=99, step=1, value=33, label="Middleground")
+             #         # mg = gr.Slider(minimum=0, maximum=99, step=1, value=66, label="Background")
+             #     with gr.Row():
+             #         submit_button = gr.Button("Submit", variant="primary")
+             with gr.Column(scale=3):
+                 model3d = gr.Model3D(
+                     label="Output", height="45em", interactive=False
+                 )
+
+             with gr.Column(scale=3):
+                 model3d = gr.Model3D(
+                     label="Output", height="45em", interactive=False
+                 )
+
+         # Disabled: the widgets this click handler needs are commented out above.
+         # submit_button.click(
+         #     handle_text_prompt,
+         #     inputs=[
+         #         input_text_box,
+         #         variance
+         #     ],
+         #     outputs=[
+         #         model3d
+         #     ]
+         # )
+
+     return interface
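+
+ # Note: build_interface() is scaffolding from an earlier 3D-output demo and is
+ # not wired to the blur pipeline; the gr.Interface below is the active UI.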
+
+ # demo = gr.Interface(sepia, gr.Image(), "image")
+
+ title = "Gaussian Blur Background App"
+ description = (
+     "Upload an image to apply a realistic background blur effect. "
+     "The app segments the foreground using BiRefNet and then applies a Gaussian "
+     "blur (σ=15) to the background, simulating a video conferencing blur effect."
+ )
+
+ iface = gr.Interface(
+     fn=extract_object,
+     inputs=[
+         gr.Image(type="pil", label="Input Image"),
+         gr.Slider(minimum=0, maximum=40, step=1, value=33, label="Middleground"),
+         gr.Slider(minimum=40, maximum=99, step=1, value=66, label="Background"),
+     ],
+     outputs=[
+         gr.Image(type="pil", label="Bokeh Image"),
+         gr.Image(type="pil", label="Lens Blur Image"),
+     ],
+     title=title,
+     description=description,
+     allow_flagging="never"
+ )
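+
+ # Quick local check without the UI (hypothetical path "sample.jpg"):
+ # bokeh_img, lens_img = extract_object(Image.open("sample.jpg").convert("RGB"), 33, 66)
+ # bokeh_img.save("bokeh.png"); lens_img.save("lens_blur.png")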
+
+ iface.queue(default_concurrency_limit=1)
+ iface.launch()