fmajer commited on
Commit
a634e56
·
1 Parent(s): 5392e1d

css change

Browse files
Files changed (3) hide show
  1. app.py +8 -5
  2. model.py +0 -1
  3. get_output.py → output.py +3 -2
app.py CHANGED
@@ -13,7 +13,7 @@ from timm.data import resolve_data_config
13
  from timm.data.transforms_factory import create_transform
14
 
15
  from model import Model
16
- from get_output import visualize_output
17
 
18
 
19
  # Use GPU if available
@@ -30,13 +30,13 @@ model.eval()
30
  state = torch.load('saved_model', map_location=torch.device('cpu'))
31
  model.load_state_dict(state['val_model_dict'])
32
 
33
- # Transform for input image
34
  config = resolve_data_config({}, model=vit)
35
  config['no_aug'] = True
36
  config['interpolation'] = 'bilinear'
37
  transform = create_transform(**config)
38
 
39
-
40
  def query_image(input_img, query, binarize, eval_threshold):
41
 
42
  PIL_image = Image.fromarray(input_img, "RGB")
@@ -49,10 +49,10 @@ def query_image(input_img, query, binarize, eval_threshold):
49
  img = visualize_output(img, output, binarize, eval_threshold)
50
  return img
51
 
52
-
53
  description = """
54
  Gradio demo for an object detection architecture,
55
- introduced in <a href="https://arxiv.org/abs/2205.06230">my bachelor thesis</a>.
56
  \n\nLorem ipsum ....
57
  *"image of a shoe"*. Refer to the <a href="https://arxiv.org/abs/2103.00020">CLIP</a> paper to see the full list of text templates used to augment the training data.
58
  """
@@ -67,6 +67,9 @@ demo = gr.Interface(
67
  ],
68
  allow_flagging = "never",
69
  cache_examples=False,
 
 
 
70
  )
71
  demo.launch(debug=True)
72
 
 
13
  from timm.data.transforms_factory import create_transform
14
 
15
  from model import Model
16
+ from output import visualize_output
17
 
18
 
19
  # Use GPU if available
 
30
  state = torch.load('saved_model', map_location=torch.device('cpu'))
31
  model.load_state_dict(state['val_model_dict'])
32
 
33
+ # Create transform for input image
34
  config = resolve_data_config({}, model=vit)
35
  config['no_aug'] = True
36
  config['interpolation'] = 'bilinear'
37
  transform = create_transform(**config)
38
 
39
+ # Inference function
40
  def query_image(input_img, query, binarize, eval_threshold):
41
 
42
  PIL_image = Image.fromarray(input_img, "RGB")
 
49
  img = visualize_output(img, output, binarize, eval_threshold)
50
  return img
51
 
52
+ # Gradio interface
53
  description = """
54
  Gradio demo for an object detection architecture,
55
+ introduced in <a href="https://www.google.com/">my bachelor thesis (link will be added)</a>.
56
  \n\nLorem ipsum ....
57
  *"image of a shoe"*. Refer to the <a href="https://arxiv.org/abs/2103.00020">CLIP</a> paper to see the full list of text templates used to augment the training data.
58
  """
 
67
  ],
68
  allow_flagging = "never",
69
  cache_examples=False,
70
+ css = """
71
+ body {background-color : grey}
72
+ """,
73
  )
74
  demo.launch(debug=True)
75
 
model.py CHANGED
@@ -1,7 +1,6 @@
1
  import torch
2
  import torch.nn as nn
3
 
4
-
5
  class Model(nn.Module):
6
  def __init__(self, vit, roberta, tokenizer, device):
7
  super().__init__()
 
1
  import torch
2
  import torch.nn as nn
3
 
 
4
  class Model(nn.Module):
5
  def __init__(self, vit, roberta, tokenizer, device):
6
  super().__init__()
get_output.py → output.py RENAMED
@@ -25,6 +25,7 @@ def enlarge_array(output):
25
 
26
  return output
27
 
 
28
  def visualize_output(image, output, binarize, threshold):
29
 
30
  image, output = preprocess(image, output, binarize, threshold)
@@ -35,9 +36,9 @@ def visualize_output(image, output, binarize, threshold):
35
  plt.axis('off')
36
  plt.imshow(image)
37
  if binarize:
38
- plt.imshow(output_mask, alpha=.67)
39
  else:
40
- plt.imshow(output_mask, alpha=.8)
41
  fig.tight_layout(pad=0)
42
  fig.canvas.draw()
43
  data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
 
25
 
26
  return output
27
 
28
+
29
  def visualize_output(image, output, binarize, threshold):
30
 
31
  image, output = preprocess(image, output, binarize, threshold)
 
36
  plt.axis('off')
37
  plt.imshow(image)
38
  if binarize:
39
+ plt.imshow(output_mask, alpha=.5)
40
  else:
41
+ plt.imshow(output_mask, alpha=.6)
42
  fig.tight_layout(pad=0)
43
  fig.canvas.draw()
44
  data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)