randomarnab committed on
Commit
13f6a8b
·
1 Parent(s): 2f14636

Upload app.py

Files changed (1)
  1. app.py +60 -0
app.py ADDED
@@ -0,0 +1,60 @@
+ # -*- coding: utf-8 -*-
+ """app
+
+ Automatically generated by Colaboratory.
+
+ Original file is located at
+     https://colab.research.google.com/drive/1Uvn7yZCyrMpOYNPb7K0G45tQZJVx8LyX
+ """
+
+ from transformers import VisionEncoderDecoderModel, ViTFeatureExtractor, AutoTokenizer
+ import gradio as gr
+ import torch
+ from PIL import Image
+
+ # Load the pretrained ViT encoder + GPT-2 decoder captioning checkpoint.
+ model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
+ feature_extractor = ViTFeatureExtractor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
+ tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
+
+ # Run on GPU when available, otherwise fall back to CPU.
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ model.to(device)
+
+ # Beam-search generation settings.
+ max_length = 16
+ num_beams = 4
+ gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
+
+ def predict_step(image):
+     # Gradio passes a single PIL image; make sure it has three channels.
+     if image.mode != "RGB":
+         image = image.convert(mode="RGB")
+
+     # Preprocess for the ViT encoder and generate a caption with beam search.
+     pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
+     pixel_values = pixel_values.to(device)
+
+     output_ids = model.generate(pixel_values, **gen_kwargs)
+
+     preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
+     preds = [pred.strip() for pred in preds]
+     return preds[0]  # one caption for the single Textbox output
+
+ # Gradio interface: one uploaded image in, one caption out.
+ inputs = [gr.inputs.Image(type='pil', label='Original Image')]
+ outputs = [gr.outputs.Textbox(label='Caption')]
+ title = 'Image Captioning using ViT + GPT2'
+ description = 'ViT and GPT2 are used here to generate a caption for the user-uploaded image.'
+ article = " <a href=' https://huggingface.co/sachin/vit2distilgpt2 '>Model Repository on Hugging Face Model Hub</a>"
+
+ gr.Interface(
+     predict_step,
+     inputs, outputs,
+     title=title,
+     description=description,
+     article=article,
+     theme='huggingface'
+ ).launch(debug=True, enable_queue=True)
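
For reference, the captioning path in predict_step can be smoke-tested without launching the Gradio UI. The following is a minimal sketch that mirrors the calls in app.py, assuming the same dependencies (transformers, torch, Pillow) are installed; the file name smoke_test.py and the image path "test.jpg" are hypothetical:

# smoke_test.py - standalone caption check mirroring predict_step in app.py.
# "test.jpg" is a hypothetical local image supplied by the user.
import torch
from PIL import Image
from transformers import VisionEncoderDecoderModel, ViTFeatureExtractor, AutoTokenizer

model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
feature_extractor = ViTFeatureExtractor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

image = Image.open("test.jpg").convert("RGB")
pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values.to(device)
output_ids = model.generate(pixel_values, max_length=16, num_beams=4)
print(tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip())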