yashikag committed
Commit 58e7332 · 1 parent: 7fa19dc

Upload image_caption.py

Files changed (1):
  1. image_caption.py (+78, -0)
image_caption.py ADDED
@@ -0,0 +1,78 @@
# -*- coding: utf-8 -*-
"""image_caption

Automatically generated by Colaboratory.

Original file is located at
https://colab.research.google.com/drive/1wo4dOccibBJyLj9E3anSLGeMCWbnIPS1
"""

# pip install transformers -q
# pip install gradio -q
from PIL import Image
from transformers import VisionEncoderDecoderModel, ViTFeatureExtractor, PreTrainedTokenizerFast
import requests

# Load the pretrained ViT-encoder / GPT-2-decoder captioning model, the ViT
# feature extractor that turns a PIL image into pixel tensors, and the GPT-2
# tokenizer used to decode the generated token IDs back into text.
model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
vit_feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
tokenizer = PreTrainedTokenizerFast.from_pretrained("distilgpt2")
# Scratch exploration from the notebook, kept commented out for reference:
# url = 'https://d2gp644kobdlm6.cloudfront.net/wp-content/uploads/2016/06/bigstock-Shocked-and-surprised-boy-on-t-113798588-300x212.jpg'
# with Image.open(requests.get(url, stream=True).raw) as img:
#     pixel_values = vit_feature_extractor(images=img, return_tensors="pt").pixel_values
# generated_ids = model.generate(pixel_values.to('cpu'), num_beams=5)
# generated_sentences = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
# generated_sentences
#
# Naive text processing: keep only the first sentence.
# generated_sentences[0].split('.')[0]
# Inference function: PIL image in, one-sentence caption out.
def vit2distilgpt2(img):
    # Preprocess the image into the pixel tensor the ViT encoder expects.
    pixel_values = vit_feature_extractor(images=img, return_tensors="pt").pixel_values
    # Beam search over the GPT-2 decoder; generate() returns token IDs.
    generated_ids = model.generate(pixel_values.to('cpu'), num_beams=5)
    generated_sentences = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

    # Naive post-processing: keep only the first sentence.
    return generated_sentences[0].split('.')[0]
# Fetch the example image for the demo. The notebook used Colab's `!wget`
# shell magic, which is not valid Python in a plain .py script; requests
# (already imported) does the same job.
img_url = "https://media.glamour.com/photos/5f171c4fd35176eaedb36823/master/w_2560%2Cc_limit/bike.jpg"
with open("bike.jpg", "wb") as f:
    f.write(requests.get(img_url).content)
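# Quick sanity check (an illustrative sketch, not in the original notebook):
# caption the example image just downloaded.
# print(vit2distilgpt2(Image.open("bike.jpg")))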
import gradio as gr

# Gradio UI components (legacy 2.x `gr.inputs` / `gr.outputs` namespaces).
inputs = [
    gr.inputs.Image(type="pil", label="Original Image")
]

outputs = [
    gr.outputs.Textbox(label='Caption')
]
title = "Image Captioning using ViT + GPT2"
description = "ViT and GPT2 are used to generate a caption for the uploaded image. The model was trained on the COCO dataset. It may have biases that we couldn't uncover during our stress testing, so if you find any bias (gender, race, and so on), please use the `Flag` button to flag the image."
article = "<a href='https://huggingface.co/vit2distilgpt2'>Model Repo on Hugging Face Model Hub</a>"
examples = [
    ["bike.jpg"]
]
gr.Interface(
    vit2distilgpt2,
    inputs,
    outputs,
    title=title,
    description=description,
    article=article,
    examples=examples,
    theme="huggingface",
).launch(debug=True, enable_queue=True)
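Note for newer environments: the `gr.inputs` / `gr.outputs` namespaces, the "huggingface" string theme, and the `enable_queue` launch flag belong to the legacy Gradio 2.x API and were deprecated and later removed. A minimal sketch of the same interface against the Gradio 3.x+ API, assuming the rest of the script above is unchanged:

    import gradio as gr

    demo = gr.Interface(
        fn=vit2distilgpt2,
        inputs=gr.Image(type="pil", label="Original Image"),
        outputs=gr.Textbox(label="Caption"),
        title=title,
        description=description,
        article=article,
        examples=[["bike.jpg"]],
    )
    demo.queue().launch(debug=True)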