chansung commited on
Commit
60451f1
·
verified ·
1 Parent(s): 3df3cdb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -5
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import tarfile
2
  import wandb
3
 
@@ -10,6 +11,7 @@ from transformers import ViTFeatureExtractor
10
  PRETRAIN_CHECKPOINT = "google/vit-base-patch16-224-in21k"
11
  feature_extractor = ViTFeatureExtractor.from_pretrained(PRETRAIN_CHECKPOINT)
12
 
 
13
  MODEL = None
14
 
15
  RESOLTUION = 224
@@ -43,11 +45,11 @@ def preprocess_input(image):
43
  "pixel_values": tf.expand_dims(image, 0)
44
  }
45
 
46
- def get_predictions(wb_token, image):
47
  global MODEL
48
 
49
  if MODEL is None:
50
- wandb.login(key=wb_token)
51
  wandb.init(project="tfx-vit-pipeline", id="gvtyqdgn", resume=True)
52
  path = wandb.use_artifact('tfx-vit-pipeline/final_model:1688113391', type='model').download()
53
 
@@ -66,8 +68,6 @@ def get_predictions(wb_token, image):
66
  with gr.Blocks() as demo:
67
  gr.Markdown("## Simple demo for a Image Classification of the Beans Dataset with HF ViT model")
68
 
69
- wb_token_if = gr.Textbox(interactive=True, label="Your Weight & Biases API Key")
70
-
71
  with gr.Row():
72
  image_if = gr.Image()
73
  label_if = gr.Label(num_top_classes=3)
@@ -76,7 +76,7 @@ with gr.Blocks() as demo:
76
 
77
  classify_if.click(
78
  get_predictions,
79
- [wb_token_if, image_if],
80
  label_if
81
  )
82
 
 
1
+ import os
2
  import tarfile
3
  import wandb
4
 
 
11
  PRETRAIN_CHECKPOINT = "google/vit-base-patch16-224-in21k"
12
  feature_extractor = ViTFeatureExtractor.from_pretrained(PRETRAIN_CHECKPOINT)
13
 
14
+ WB_KEY = os.environ['WB_KEY']
15
  MODEL = None
16
 
17
  RESOLTUION = 224
 
45
  "pixel_values": tf.expand_dims(image, 0)
46
  }
47
 
48
+ def get_predictions(image):
49
  global MODEL
50
 
51
  if MODEL is None:
52
+ wandb.login(key=WB_KEY)
53
  wandb.init(project="tfx-vit-pipeline", id="gvtyqdgn", resume=True)
54
  path = wandb.use_artifact('tfx-vit-pipeline/final_model:1688113391', type='model').download()
55
 
 
68
  with gr.Blocks() as demo:
69
  gr.Markdown("## Simple demo for a Image Classification of the Beans Dataset with HF ViT model")
70
 
 
 
71
  with gr.Row():
72
  image_if = gr.Image()
73
  label_if = gr.Label(num_top_classes=3)
 
76
 
77
  classify_if.click(
78
  get_predictions,
79
+ image_if,
80
  label_if
81
  )
82