Update app.py
Browse files
app.py
CHANGED
@@ -159,11 +159,24 @@ def infer_inp(prompt, audio_path, spec_with_mask, progress=gr.Progress(track_tqd
|
|
159 |
raw_image = image_add_color(torch_to_pil(norm_spec))
|
160 |
|
161 |
# Add Mask
|
162 |
-
mask = torch.zeros_like(norm_spec)[:1,...]
|
163 |
-
mask[:, :, width_start:width_start+width] = 1
|
164 |
-
mask_image = torch_to_pil(mask)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
165 |
|
166 |
-
mask, masked_spec = prepare_mask_and_masked_image(norm_spec,
|
167 |
masked_spec_image = torch_to_pil(masked_spec)
|
168 |
|
169 |
# color masked spec and paint masked area to black
|
|
|
159 |
raw_image = image_add_color(torch_to_pil(norm_spec))
|
160 |
|
161 |
# Add Mask
|
162 |
+
#mask = torch.zeros_like(norm_spec)[:1,...]
|
163 |
+
#mask[:, :, width_start:width_start+width] = 1
|
164 |
+
#mask_image = torch_to_pil(mask)
|
165 |
+
|
166 |
+
# Load the mask image (input from user)
|
167 |
+
mask_pil = spec_with_mask['layers'][0]
|
168 |
+
|
169 |
+
# Convert to tensor and normalize
|
170 |
+
mask_tensor = transforms.ToTensor()(mask_pil) # Shape: (1, H, W), values in [0, 1]
|
171 |
+
|
172 |
+
# Ensure the shape matches expected input (add batch dimension if needed)
|
173 |
+
mask_tensor = mask_tensor[:1, :, :] # Keep only one channel (grayscale)
|
174 |
+
mask_tensor = mask_tensor.to(device, dtype) # Send to correct device and dtype
|
175 |
+
|
176 |
+
# Convert to PIL image if needed for visualization
|
177 |
+
mask_image = torch_to_pil(mask_tensor)
|
178 |
|
179 |
+
mask, masked_spec = prepare_mask_and_masked_image(norm_spec, mask_tensor)
|
180 |
masked_spec_image = torch_to_pil(masked_spec)
|
181 |
|
182 |
# color masked spec and paint masked area to black
|