Spaces:
Sleeping
Sleeping
David Vaillant
committed on
Commit
·
a073fdd
1
Parent(s):
5b4a37c
Basic func.
Browse files- baby_shiny.py +102 -0
- backend.py +72 -0
- checkpoints/bbox_finetune.ckpt +3 -0
baby_shiny.py
ADDED
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from shiny import App, Inputs, Outputs, Session, reactive, render, ui
|
2 |
+
from shiny.types import FileInfo, ImgData
|
3 |
+
import asyncio
|
4 |
+
import concurrent.futures
|
5 |
+
|
6 |
+
import backend
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
from PIL import Image, ImageDraw
|
10 |
+
from pathlib import Path
|
11 |
+
import tempfile
|
12 |
+
|
13 |
+
|
14 |
+
def draw_layer_on_image(im: Image.Image) -> Image.Image:
    """Draw a diagonal cross over an image, in place.

    Two lines are drawn with ``fill=128``: a 5-pixel-wide line from the
    top-left corner to the bottom-right corner, and a second line (no
    explicit width) from the bottom-left corner to the top-right corner.

    Args:
        im: The PIL image to draw on. It is modified in place.

    Returns:
        The same image object that was passed in.
    """
    width, height = im.size
    draw = ImageDraw.Draw(im)
    draw.line((0, 0, width, height), fill=128, width=5)
    draw.line((0, height, width, 0), fill=128)
    return im
|
27 |
+
|
28 |
+
|
29 |
+
# Planned page layout (the title and the arrow are not built yet):
#   - title element, centered
#   - file input, centered
#   - a table in the middle: the uploaded image on the left,
#     an arrow in the middle, the mask on the right.
card_height = '700px'

_upload_input = ui.input_file(
    "file1",
    "Upload a sidewalk.",
    accept=[".jpg", ".png", ".jpeg"],
    multiple=False,
)

# Left-hand card: echoes the uploaded file back to the user.
_uploaded_card = ui.card(
    ui.card_header("Uploaded Image"),
    ui.output_image("show_image"),
    height=card_height,
)

# Right-hand card: shows the mask rendered by the server's `samwalk` output.
# A "Process mask" task button (id "mask_btn") is currently disabled.
_mask_card = ui.card(
    ui.card_header("Image Mask"),
    ui.output_image("samwalk"),
    height=card_height,
)

app_ui = ui.page_fixed(
    _upload_input,
    ui.layout_columns(_uploaded_card, _mask_card),
)
|
51 |
+
|
52 |
+
def strip_alpha(image: Image) -> Image:
    """Flatten an image onto an opaque white canvas and return it as RGB.

    The image is alpha-composited over a white RGBA background of the same
    size, then the result is converted to mode ``'RGB'``.
    """
    opaque_white = (255, 255, 255, 255)
    canvas = Image.new('RGBA', image.size, opaque_white)
    return Image.alpha_composite(canvas, image).convert('RGB')
|
58 |
+
|
59 |
+
def server(input: Inputs, output: Outputs, session: Session):
    """Define the reactive outputs: the uploaded image and its mask.

    Args:
        input: Shiny inputs; ``input.file1`` is the file-upload control.
        output: Shiny outputs (unused directly; outputs register via decorators).
        session: The Shiny session (unused).
    """

    @reactive.calc
    def parsed_file():
        """Return the first uploaded file's info, or None if nothing is uploaded."""
        file: list[FileInfo] | None = input.file1()
        if not file:
            return None
        return file[0]

    @render.image
    def show_image():
        """Render the uploaded file as-is."""
        uploaded_file = parsed_file()
        if uploaded_file is None:
            return None
        img: ImgData = {"src": str(uploaded_file['datapath']), "width": "500px"}
        return img

    # TODO: gate this behind the (currently disabled) "mask_btn" button and
    # run the model off the main thread; for now it runs synchronously.
    @render.image
    def samwalk():
        """Run the backend on the uploaded image and render the resulting mask."""
        uploaded_file = parsed_file()
        if uploaded_file is None:
            return None
        uploaded_src = uploaded_file['datapath']
        uploaded_img = Image.open(uploaded_src)
        if uploaded_img.mode == 'RGBA':
            uploaded_img = strip_alpha(uploaded_img)

        output_img = backend.process_image(uploaded_img)

        # Join only the file *name* onto the temp dir. Joining the full
        # (absolute) datapath would discard the temp dir and make the mask
        # overwrite the uploaded file itself.
        output_path = Path(tempfile.mkdtemp()) / Path(uploaded_src).name
        output_img.save(output_path)
        return {"src": str(output_path), "width": "500px"}
|
99 |
+
|
100 |
+
|
101 |
+
|
102 |
+
# Shiny app object built from the UI definition and the server function.
app = App(app_ui, server)
|
backend.py
ADDED
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# backend.py
|
2 |
+
import numpy as np
|
3 |
+
from PIL import Image, ImageDraw
|
4 |
+
import torch
|
5 |
+
from transformers import SamModel, SamProcessor
|
6 |
+
from torchvision.transforms import v2
|
7 |
+
from samgeo.text_sam import LangSAM
|
8 |
+
import os
|
9 |
+
import logging
|
10 |
+
|
11 |
+
|
12 |
+
# Preprocessing pipeline: PIL image -> tensor, then to float32 in [0, 1].
_to_tensor = v2.PILToTensor()
_to_unit_float = v2.ToDtype(torch.float32, scale=True)
preproc = v2.Compose([_to_tensor, _to_unit_float])
|
16 |
+
|
17 |
+
|
18 |
+
# Load the necessary models.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Fine-tuned SAM weights. Override the location with SAM_FINETUNE_CHECKPOINT.
# The default matches the checkpoint file tracked in the repository
# (checkpoints/bbox_finetune.ckpt).
# NOTE(review): assumes the checkpoint holds a plain state_dict that
# load_state_dict() accepts directly -- TODO confirm.
CHECKPOINT_FILE = os.getenv("SAM_FINETUNE_CHECKPOINT", "checkpoints/bbox_finetune.ckpt")

# NOTE(review): the processor is loaded from "sam-vit-base" while the model is
# "sam-vit-large" -- confirm this pairing is intentional.
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
tuned_model = SamModel.from_pretrained("facebook/sam-vit-large").to(device)
tuned_model.load_state_dict(torch.load(CHECKPOINT_FILE, map_location=device))
langsam_model = LangSAM("vit_l")
|
27 |
+
|
28 |
+
|
29 |
+
def process_image(image: Image.Image, bbox: "list[float] | None" = None) -> Image.Image:
    """Compute the sidewalk mask for an image with the fine-tuned SAM model.

    Args:
        image: The PIL image to segment.
        bbox: Optional box prompt as four numbers. When omitted, the
            bounding box of the image's non-zero region is used; if the
            image has no such region, the full image extent is used.

    Returns:
        The greyscale mask image produced by ``get_sidewalk_mask``.
    """
    logging.info("Logging image information.")
    if bbox is None:
        # No bbox information. Use default (filters out zeroes)
        logging.debug("Using default, null bounding box.")
        # getbbox() returns None for a completely empty image; fall back to
        # the whole frame instead of failing on float conversion.
        bbox = image.getbbox() or (0, 0, *image.size)
    bbox = [float(coord) for coord in bbox]

    inputs = processor(preproc(image), input_boxes=[[bbox]],
                       do_rescale=False, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}  # Map objects to our device.

    mask = get_sidewalk_mask(tuned_model, inputs)
    # TODO: also compute tree masks and combine them with the sidewalk mask.
    return mask
|
43 |
+
|
44 |
+
|
45 |
+
def get_sidewalk_mask(model, inputs) -> Image:
    """Run ``model`` on ``inputs`` and return its mask as a greyscale image.

    The predicted mask logits are passed through a sigmoid, values below
    0.5 are zeroed, and the remainder is scaled by 255 before conversion
    to a mode ``'L'`` PIL image.
    """
    logging.info("Calculating mask.")
    model.eval()
    with torch.no_grad():
        prediction = model(**inputs, multimask_output=False)

    # Sigmoid on the logits, then hand over to numpy for the thresholding.
    logits = prediction.pred_masks.squeeze(1)
    probs = torch.sigmoid(logits).cpu().numpy().squeeze()

    # Drop low-confidence pixels, then map probability linearly to intensity.
    probs[probs < 0.5] = 0
    probs *= 255

    return Image.fromarray(probs).convert('L')
|
63 |
+
|
64 |
+
|
65 |
+
def get_tree_masks(image: Image):
    """Run the LangSAM model on ``image`` with the text prompt "tree".

    NOTE(review): the value returned by ``predict`` is discarded, so this
    function returns None -- presumably unfinished; confirm the intended
    return value.
    """
    detection_threshold = 0.24
    langsam_model.predict(
        image,
        "tree",
        box_threshold=detection_threshold,
        text_threshold=detection_threshold,
    )
|
67 |
+
|
68 |
+
|
69 |
+
# masks, boxes, phrases, logits = tuned_model.predict(image_pil, bbox)
|
70 |
+
# tree_data = langsam_model.predict(image_pil, text_prompt)
|
71 |
+
|
72 |
+
# def draw_layer_on_image(model, im: Image, text_prompt: str='sidewalk') -> Image:
|
checkpoints/bbox_finetune.ckpt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:4c72e371f7cd4644c9d9550649db4a5473ad63c21472b9d0973670d0dff1ff69
|
3 |
+
size 1249561500
|