fmajer committed
Commit 2c5aba6 · 1 Parent(s): 9e7f571

Object detection app

Files changed (6)
  1. .DS_Store +0 -0
  2. app.py +67 -0
  3. examples/img1.jpeg +0 -0
  4. get_output.py +46 -0
  5. model.py +40 -0
  6. requirements.txt +10 -0
.DS_Store ADDED
Binary file (6.15 kB)
 
app.py ADDED
@@ -0,0 +1,67 @@
+ import torch
+ import torch.nn as nn
+ import cv2
+ import gradio as gr
+ import numpy as np
+ from PIL import Image
+ import transformers
+ from transformers import RobertaModel, RobertaTokenizer
+ import timm
+ import pandas as pd
+ import matplotlib.pyplot as plt
+ from timm.data import resolve_data_config
+ from timm.data.transforms_factory import create_transform
+
+ from model import Model
+ from get_output import visualize_output
+
+
+ # Use GPU if available
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+ # Initialize models
+ vit = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=0, global_pool='').to(device)
+ tokenizer = RobertaTokenizer.from_pretrained('roberta-base', truncation=True, do_lower_case=True)
+ roberta = RobertaModel.from_pretrained("roberta-base")
+ model = Model(vit, roberta, tokenizer, device).to(device)
+ model.eval()
+
+ config = resolve_data_config({}, model=vit)  # ViT preprocessing: resize + normalize, no augmentation
+ config['no_aug'] = True
+ config['interpolation'] = 'bilinear'
+ transform = create_transform(**config)
+
+
+ def query_image(input_img, query, binarize, eval_threshold):
+
+     PIL_image = Image.fromarray(input_img, "RGB")
+     img = transform(PIL_image)
+     img = torch.unsqueeze(img, 0).to(device)
+
+     with torch.no_grad():
+         output = model(img, query)
+
+     img = visualize_output(img, output, binarize, eval_threshold)
+     return img
+
+
+ description = """
+ Gradio demo for an object detection architecture,
+ introduced in <a href="https://arxiv.org/abs/2205.06230">my bachelor thesis</a>.
+ \n\nLorem ipsum ....
+ *"image of a shoe"*. Refer to the <a href="https://arxiv.org/abs/2103.00020">CLIP</a> paper to see the full list of text templates used to augment the training data.
+ """
+ demo = gr.Interface(
+     query_image,
+     inputs=[gr.Image(), "text", "checkbox", gr.Slider(0, 1, value=0.25)],
+     outputs="image",
+     title="Object Detection Using Textual Queries",
+     description=description,
+     examples=[
+         ["examples/img1.jpeg", "Find a person.", True, 0.25],
+     ],
+     allow_flagging="never",
+     cache_examples=False,
+ )
+ demo.launch(debug=True)
+
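For a quick smoke test without launching the Gradio UI, `query_image` can be called directly on the bundled example; a minimal sketch, assuming it runs inside app.py's scope (the output path `prediction.png` is illustrative):

```python
import numpy as np
from PIL import Image

# Gradio's gr.Image() hands the callback an HxWx3 uint8 RGB array.
input_img = np.array(Image.open("examples/img1.jpeg").convert("RGB"))

# Same arguments as the demo's example row: query text, binarize flag, threshold.
result = query_image(input_img, "Find a person.", binarize=True, eval_threshold=0.25)

# query_image returns the rendered overlay as an RGB uint8 array.
Image.fromarray(result).save("prediction.png")  # hypothetical output path
```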
examples/img1.jpeg ADDED
get_output.py ADDED
@@ -0,0 +1,46 @@
+ import numpy as np
+ import pandas as pd
+ import matplotlib.pyplot as plt
+ from PIL import Image
+
+
+ def preprocess(image, output, binarize, threshold):
+
+     image = image.cpu().detach().numpy().squeeze()
+     image = np.transpose(image, (1, 2, 0))  # CHW -> HWC for plotting
+     image = (image + 1) * 0.5  # undo the [-1, 1] normalization for display
+     output = output.cpu().detach().numpy().squeeze()
+
+     if binarize:
+         output = np.where(output > threshold, 1., 0.)
+
+     return image, output
+
+
+ def enlarge_array(output):
+     df = pd.DataFrame(np.reshape(output, (14, 14)))  # one score per 16x16 patch
+     df = pd.DataFrame(np.repeat(df.values, 16, axis=0))
+     df = pd.DataFrame(np.repeat(df.values, 16, axis=1))  # upsample to 224x224
+     output = df.to_numpy()
+
+     return output
+
+ def visualize_output(image, output, binarize, threshold):
+
+     image, output = preprocess(image, output, binarize, threshold)
+     output = enlarge_array(output)
+     output_mask = Image.fromarray((output * 255).astype(np.uint8))  # uint8 so PIL accepts the array
+
+     fig = plt.figure(figsize=(6, 6))
+     plt.axis('off')
+     plt.imshow(image)
+     if binarize:
+         plt.imshow(output_mask, alpha=.67)
+     else:
+         plt.imshow(output_mask, alpha=.8)
+     fig.tight_layout(pad=0)
+     fig.canvas.draw()
+     data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)  # render figure to an RGB array
+     data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
+
+     return data
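The pandas round trip in `enlarge_array` is nearest-neighbour upsampling of the 14x14 patch grid to the 224x224 input resolution; a small sketch showing it matches a dependency-free Kronecker-product formulation (array names here are illustrative):

```python
import numpy as np

scores = np.random.rand(196)  # one score per 16x16 patch of a 224x224 image
grid = scores.reshape(14, 14)

# Repeating each cell into a 16x16 block equals a Kronecker product with ones.
by_repeat = np.repeat(np.repeat(grid, 16, axis=0), 16, axis=1)
by_kron = np.kron(grid, np.ones((16, 16)))

assert by_repeat.shape == (224, 224)
assert np.allclose(by_repeat, by_kron)
```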
model.py ADDED
@@ -0,0 +1,40 @@
+ import torch
+ import torch.nn as nn
+
+
+ class Model(nn.Module):
+     def __init__(self, vit, roberta, tokenizer, device):
+         super().__init__()
+         self.bertmap = nn.Conv1d(768, 768, 1)  # 1x1 conv projecting RoBERTa embeddings
+         self.vitmap = nn.Conv1d(768, 768, 1)  # 1x1 conv projecting ViT patch embeddings
+         self.conv1d = nn.Conv1d(1, 1, 1)  # learnable scale/shift on the similarity scores
+         self.add_module("vit", vit)
+         self.add_module("roberta", roberta)
+         self.tokenizer = tokenizer
+         self.conv1d.weight = torch.nn.Parameter(torch.tensor([[[1.]]]))  # start as identity
+         self.conv1d.bias = torch.nn.Parameter(torch.tensor([0.]))
+         self.device = device
+
+     def forward(self, image, cats):
+         vit_out = self.vit(image)
+         vit_out = vit_out[:, 1:, :]  # drop the [CLS] token, keep the 196 patch tokens
+         vit_out = torch.transpose(vit_out, 2, 1)
+         vit_out = self.vitmap(vit_out)
+         vit_out = torch.transpose(vit_out, 2, 1)  # (B, 196, 768)
+         token_out = self.tokenizer.encode_plus(
+             cats,
+             padding=True,
+             add_special_tokens=True,
+             return_token_type_ids=True,
+             return_tensors='pt'
+         ).to(self.device)
+         bert_out = self.roberta(**token_out)
+         hidden_state = bert_out.last_hidden_state
+         hidden_state = torch.transpose(hidden_state, 2, 1)
+         hidden_state = self.bertmap(hidden_state)
+         hidden_state = torch.transpose(hidden_state, 2, 1)
+         pooled_bert_out = hidden_state[:, 0]  # <s> token embedding as the pooled query vector
+         pooled_bert_out = torch.unsqueeze(pooled_bert_out, dim=2)  # (B, 768, 1)
+         out = torch.matmul(vit_out, pooled_bert_out)  # dot product per patch -> (B, 196, 1)
+         out = torch.transpose(out, 2, 1)
+         return torch.squeeze(self.conv1d(out), dim=1)  # (B, 196) patch scores
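The forward pass reads most clearly as a shape walk-through; a minimal sketch of the core similarity step with dummy tensors (196 = 14x14 patches of a 224px image with a 16px patch size):

```python
import torch

B, P, D = 1, 196, 768             # batch, ViT patches (14x14), embedding dim
patch_emb = torch.randn(B, P, D)  # projected ViT patch embeddings
query_emb = torch.randn(B, D, 1)  # pooled, projected RoBERTa query vector

# One dot product per patch: (B, 196, 768) @ (B, 768, 1) -> (B, 196, 1)
scores = torch.matmul(patch_emb, query_emb)
print(scores.squeeze(-1).shape)   # torch.Size([1, 196]); reshaped to 14x14 for display
```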
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ # pip install -r requirements.txt
+
+ numpy>=1.18.5
+ torch>=1.7.0
+ torchvision>=0.8.1
+ git+https://github.com/huggingface/transformers.git
+ opencv-python
+ pandas
+ matplotlib
+ timm