css change
Browse files- app.py +8 -5
- model.py +0 -1
- get_output.py → output.py +3 -2
app.py
CHANGED
@@ -13,7 +13,7 @@ from timm.data import resolve_data_config
|
|
13 |
from timm.data.transforms_factory import create_transform
|
14 |
|
15 |
from model import Model
|
16 |
-
from
|
17 |
|
18 |
|
19 |
# Use GPU if available
|
@@ -30,13 +30,13 @@ model.eval()
|
|
30 |
state = torch.load('saved_model', map_location=torch.device('cpu'))
|
31 |
model.load_state_dict(state['val_model_dict'])
|
32 |
|
33 |
-
#
|
34 |
config = resolve_data_config({}, model=vit)
|
35 |
config['no_aug'] = True
|
36 |
config['interpolation'] = 'bilinear'
|
37 |
transform = create_transform(**config)
|
38 |
|
39 |
-
|
40 |
def query_image(input_img, query, binarize, eval_threshold):
|
41 |
|
42 |
PIL_image = Image.fromarray(input_img, "RGB")
|
@@ -49,10 +49,10 @@ def query_image(input_img, query, binarize, eval_threshold):
|
|
49 |
img = visualize_output(img, output, binarize, eval_threshold)
|
50 |
return img
|
51 |
|
52 |
-
|
53 |
description = """
|
54 |
Gradio demo for an object detection architecture,
|
55 |
-
introduced in <a href="https://
|
56 |
\n\nLorem ipsum ....
|
57 |
*"image of a shoe"*. Refer to the <a href="https://arxiv.org/abs/2103.00020">CLIP</a> paper to see the full list of text templates used to augment the training data.
|
58 |
"""
|
@@ -67,6 +67,9 @@ demo = gr.Interface(
|
|
67 |
],
|
68 |
allow_flagging = "never",
|
69 |
cache_examples=False,
|
|
|
|
|
|
|
70 |
)
|
71 |
demo.launch(debug=True)
|
72 |
|
|
|
13 |
from timm.data.transforms_factory import create_transform
|
14 |
|
15 |
from model import Model
|
16 |
+
from output import visualize_output
|
17 |
|
18 |
|
19 |
# Use GPU if available
|
|
|
30 |
state = torch.load('saved_model', map_location=torch.device('cpu'))
|
31 |
model.load_state_dict(state['val_model_dict'])
|
32 |
|
33 |
+
# Create transform for input image
|
34 |
config = resolve_data_config({}, model=vit)
|
35 |
config['no_aug'] = True
|
36 |
config['interpolation'] = 'bilinear'
|
37 |
transform = create_transform(**config)
|
38 |
|
39 |
+
# Inference function
|
40 |
def query_image(input_img, query, binarize, eval_threshold):
|
41 |
|
42 |
PIL_image = Image.fromarray(input_img, "RGB")
|
|
|
49 |
img = visualize_output(img, output, binarize, eval_threshold)
|
50 |
return img
|
51 |
|
52 |
+
# Gradio interface
|
53 |
description = """
|
54 |
Gradio demo for an object detection architecture,
|
55 |
+
introduced in <a href="https://www.google.com/">my bachelor thesis (link will be added)</a>.
|
56 |
\n\nLorem ipsum ....
|
57 |
*"image of a shoe"*. Refer to the <a href="https://arxiv.org/abs/2103.00020">CLIP</a> paper to see the full list of text templates used to augment the training data.
|
58 |
"""
|
|
|
67 |
],
|
68 |
allow_flagging = "never",
|
69 |
cache_examples=False,
|
70 |
+
css = """
|
71 |
+
body {background-color : grey}
|
72 |
+
""",
|
73 |
)
|
74 |
demo.launch(debug=True)
|
75 |
|
model.py
CHANGED
@@ -1,7 +1,6 @@
|
|
1 |
import torch
|
2 |
import torch.nn as nn
|
3 |
|
4 |
-
|
5 |
class Model(nn.Module):
|
6 |
def __init__(self, vit, roberta, tokenizer, device):
|
7 |
super().__init__()
|
|
|
1 |
import torch
|
2 |
import torch.nn as nn
|
3 |
|
|
|
4 |
class Model(nn.Module):
|
5 |
def __init__(self, vit, roberta, tokenizer, device):
|
6 |
super().__init__()
|
get_output.py → output.py
RENAMED
@@ -25,6 +25,7 @@ def enlarge_array(output):
|
|
25 |
|
26 |
return output
|
27 |
|
|
|
28 |
def visualize_output(image, output, binarize, threshold):
|
29 |
|
30 |
image, output = preprocess(image, output, binarize, threshold)
|
@@ -35,9 +36,9 @@ def visualize_output(image, output, binarize, threshold):
|
|
35 |
plt.axis('off')
|
36 |
plt.imshow(image)
|
37 |
if binarize:
|
38 |
-
plt.imshow(output_mask, alpha=.
|
39 |
else:
|
40 |
-
plt.imshow(output_mask, alpha=.
|
41 |
fig.tight_layout(pad=0)
|
42 |
fig.canvas.draw()
|
43 |
data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
|
|
|
25 |
|
26 |
return output
|
27 |
|
28 |
+
|
29 |
def visualize_output(image, output, binarize, threshold):
|
30 |
|
31 |
image, output = preprocess(image, output, binarize, threshold)
|
|
|
36 |
plt.axis('off')
|
37 |
plt.imshow(image)
|
38 |
if binarize:
|
39 |
+
plt.imshow(output_mask, alpha=.5)
|
40 |
else:
|
41 |
+
plt.imshow(output_mask, alpha=.6)
|
42 |
fig.tight_layout(pad=0)
|
43 |
fig.canvas.draw()
|
44 |
data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
|