Spaces:
Paused
Paused
File size: 3,815 Bytes
b9f0115 e706d2b f43d01b e706d2b b9f0115 29de13b f43d01b a5f7c39 f43d01b dc5a8c5 a770601 998456b f43d01b 77bb1b9 f43d01b b9f0115 77bb1b9 0ecdec8 77bb1b9 b9f0115 a83472e b9f0115 e706d2b 5dc6bd9 8d19b59 b9f0115 5f32f95 5dc6bd9 b9f0115 5dc6bd9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 |
import numpy as np
import torch
import sys
import os
from fastai.vision.all import *
import gradio as gr
############### HF ###########################
# Hugging Face access token read from the environment (None when HF_TOKEN is unset).
HF_TOKEN = os.environ.get('HF_TOKEN')
# Flagging callback handed to the Gradio interface in main(); gradio's
# HuggingFaceDatasetSaver stores flagged samples in the "savtadepth-flags" dataset.
hf_writer = gr.HuggingFaceDatasetSaver(HF_TOKEN, "savtadepth-flags")
############## DVC ################################
PROD_MODEL_PATH = "src/models"
TRAIN_PATH = "src/data/processed/train/bathroom"
TEST_PATH = "src/data/processed/test/bathroom"

if os.path.isdir(".dvc"):
    # A checkout that still carries DVC metadata: pull the model weights and the
    # data paths above, then delete the metadata so the next start skips this.
    import shutil

    print("Running DVC")
    # Materialise files as copies rather than cache links, and run DVC without
    # SCM (git) integration.
    os.system("dvc config cache.type copy")
    os.system("dvc config core.no_scm true")
    if os.system(f"dvc pull {PROD_MODEL_PATH} {TRAIN_PATH} {TEST_PATH}") != 0:
        # sys.exit, not the site-provided exit(), which is missing under `python -S`.
        sys.exit("dvc pull failed")
    # Best-effort removal, like the previous unchecked `rm -r .dvc`.
    shutil.rmtree(".dvc", ignore_errors=True)
    # .apt/usr/lib/dvc
    # NOTE(review): the path above is presumably where an apt-installed dvc lives; confirm or remove.
############## Inference ##############################
class ImageImageDataLoaders(DataLoaders):
    """`DataLoaders` with factory methods for image-to-image problems.

    Inputs are loaded as `PILImage` and targets as single-channel `PILImageBW`.
    """

    @classmethod
    @delegates(DataLoaders.from_dblock)
    def from_label_func(cls, path, filenames, label_func, valid_pct=0.2, seed=None, item_transforms=None,
                        batch_transforms=None, **kwargs):
        """Build DataLoaders from `filenames` rooted at `path`.

        Args:
            path: root folder handed to `from_dblock`.
            filenames: input image files.
            label_func: maps an input file to its target image file.
            valid_pct: fraction of items put in the validation set by `RandomSplitter`.
            seed: seed for that random split.
            item_transforms: passed to the DataBlock as `item_tfms`.
            batch_transforms: passed to the DataBlock as `batch_tfms`.
            **kwargs: forwarded to `DataLoaders.from_dblock` (e.g. `bs`, `num_workers`).
        """
        input_block = ImageBlock(cls=PILImage)
        target_block = ImageBlock(cls=PILImageBW)
        splitter = RandomSplitter(valid_pct, seed=seed)
        block = DataBlock(
            blocks=(input_block, target_block),
            get_y=label_func,
            splitter=splitter,
            item_tfms=item_transforms,
            batch_tfms=batch_transforms,
        )
        return cls.from_dblock(block, filenames, path=path, **kwargs)
def get_y_fn(x):
    """Return the depth-map target path for input image `x`.

    `<dir>/<name>.jpg` maps to `<dir>/<name>_depth.png`, as an absolute `Path`.
    Only the file's own suffix is rewritten, and it is matched case-insensitively.
    This matters because fastai's `get_files` extension filter is case-insensitive,
    and a '.jpg' inside a directory name must not be touched.

    A path whose suffix is not `.jpg` is returned unchanged (absolute), which is
    what the former string replacement did for such names.

    Args:
        x: path to the input image (`Path` or `str`).

    Returns:
        Absolute `Path` of the expected depth image.
    """
    src = Path(x).absolute()
    if src.suffix.lower() != '.jpg':
        return src
    return src.with_name(f"{src.stem}_depth.png")
def create_data(data_path, bs=4, seed=42, num_workers=0):
    """Build image-to-depth DataLoaders from the `train` subfolder of `data_path`.

    Inputs are the `.jpg` files that fastai's `get_files` finds under
    `<data_path>/train`. Each target is the path returned by `get_y_fn`.

    Args:
        data_path: dataset root (`str` or `Path`) that contains `train/`.
        bs: batch size (default 4, as before).
        seed: seed for the random train/validation split (default 42).
        num_workers: DataLoader worker count (default 0, as before).

    Returns:
        The `ImageImageDataLoaders` built by `from_label_func`.
    """
    train_dir = Path(data_path) / 'train'
    fnames = get_files(train_dir, extensions='.jpg')
    return ImageImageDataLoaders.from_label_func(
        train_dir,
        seed=seed,
        bs=bs,
        num_workers=num_workers,
        filenames=fnames,
        label_func=get_y_fn,
    )
# DataLoaders over the processed training images; the U-Net learner is built on them.
data = create_data(Path('src/data/processed'))
learner = unet_learner(
    data,
    resnet34,
    metrics=rmse,
    wd=1e-2,
    n_out=3,
    loss_func=MSELossFlat(),
    path='src/',
)
# Restore the trained weights saved under the name 'model'.
# NOTE(review): with path='src/' and fastai's default model_dir this is presumably
# src/models/model.pth (the PROD_MODEL_PATH pulled by DVC) — confirm.
learner.load('model')
def gen(input_img):
    """Run the loaded learner on `input_img` and return the prediction as a grayscale image.

    `input_img` is what the Gradio image input in main() delivers (type='numpy').
    """
    outputs = learner.predict(input_img)
    # Element 0 of Learner.predict's result tuple is the decoded prediction.
    decoded = outputs[0]
    depth_image = PILImageBW.create(decoded)
    return depth_image.convert('L')
################### Gradio Web APP ################################
# Text shown by the Gradio interface built in main().
title = "SavtaDepth WebApp"
description = (
    "Savta Depth is a collaborative Open Source Data Science project for monocular depth "
    "estimation - Turn 2d photos into 3d photos. To test the model and code please check "
    "out the links below."
)
# HTML footer: two centred links. Every <p> and <a> tag is closed; the former
# trailing unmatched "</center></p>" has been removed.
article = (
    "<p style='text-align: center'>"
    "<a href='https://dagshub.com/OperationSavta/SavtaDepth' target='_blank'>"
    "SavtaDepth Project from OperationSavta</a></p>"
    "<p style='text-align: center'>"
    "<a href='https://colab.research.google.com/drive/1XU4DgQ217_hUMU1dllppeQNw3pTRlHy1?usp=sharing' target='_blank'>"
    "Google Colab Demo</a></p>"
)
# One inner list per example, matching the interface's single image input.
examples = [
    ["examples/00008.jpg"],
    ["examples/00045.jpg"],
]
# Asset paths; main() in this file does not currently pass these to Gradio.
favicon = "examples/favicon.ico"
thumbnail = "examples/SavtaDepth.png"
def main():
    """Build the SavtaDepth Gradio interface around `gen` and launch it with queuing enabled."""
    image_input = gr.inputs.Image(shape=(640, 480), type='numpy')
    interface_options = dict(
        title=title,
        description=description,
        article=article,
        examples=examples,
        theme="peach",
        allow_screenshot=True,
        # Users flag outputs manually with one of three reasons; hf_writer records them.
        allow_flagging="manual",
        flagging_options=["incorrect", "worst", "ambiguous"],
        flagging_callback=hf_writer,
    )
    demo = gr.Interface(gen, image_input, "image", **interface_options)
    demo.launch(enable_queue=True)
# Launch the web app only when run as a script, not on import.
if __name__ == '__main__':
    main()