Pretam1 committed on
Commit
d3e0996
·
verified ·
1 Parent(s): b0ede06

made new file

Browse files
Files changed (1) hide show
  1. app.py +54 -0
app.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
3
+ import torch
4
+ from PIL import Image
5
+
6
# Hub checkpoint that provides the model, image processor and tokenizer.
_CHECKPOINT = "nlpconnect/vit-gpt2-image-captioning"


@st.cache_resource
def _load_captioner():
    """Load the captioning model, image processor and tokenizer once.

    Streamlit re-executes this script on every widget interaction, so the
    loaded objects are cached instead of being rebuilt on each rerun.

    Returns:
        Tuple ``(model, feature_extractor, tokenizer, device)``; the model has
        already been moved to ``device``.
    """
    # Prefer the GPU when one is available, otherwise fall back to the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = VisionEncoderDecoderModel.from_pretrained(_CHECKPOINT)
    model.to(device)
    feature_extractor = ViTImageProcessor.from_pretrained(_CHECKPOINT)
    tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT)
    return model, feature_extractor, tokenizer, device


# Module-level names kept so the rest of the file can keep using them.
model, feature_extractor, tokenizer, device = _load_captioner()

# Generation settings forwarded to ``model.generate`` for every caption.
max_length = 16
num_beams = 4
gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
19
+
20
+ # Define the prediction function
21
def predict_caption(image):
    """Generate a caption for a single PIL image.

    Args:
        image: PIL image; converted to RGB first when it is in another mode.

    Returns:
        The decoded text of the first generated sequence, with surrounding
        whitespace stripped.
    """
    rgb_image = image if image.mode == "RGB" else image.convert(mode="RGB")

    # Preprocess a one-image batch and place the tensor on the model's device.
    encoded = feature_extractor(images=[rgb_image], return_tensors="pt")
    pixel_values = encoded.pixel_values.to(device)

    # Generate token ids with the module-level generation settings, then
    # decode them to text without special tokens.
    generated_ids = model.generate(pixel_values, **gen_kwargs)
    captions = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
    return captions[0].strip()
34
+
35
+ # Main function for Streamlit app
36
def main():
    """Render the page: title, upload widget, image preview and caption.

    Nothing beyond the title, intro text and uploader is drawn until a file
    is uploaded. A file that cannot be opened or decoded as an image is
    reported with an error message instead of an unhandled exception.
    """
    st.title("Image Caption Generator")
    st.write("Upload an image, and the model will describe what it sees.")

    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
    if uploaded_file is None:
        return

    try:
        image = Image.open(uploaded_file)
        # Image.open is lazy; decode now so a corrupt or truncated file
        # fails here rather than later inside st.image or the model.
        image.load()
    except OSError:
        # PIL.UnidentifiedImageError is a subclass of OSError.
        st.error("The uploaded file could not be read as an image.")
        return

    st.image(image, caption='Uploaded Image', use_column_width=True)

    caption = predict_caption(image)
    st.write("Caption:", caption)
51
+
52
# Entry point: build the page only when this file is executed as a script.
if __name__ == "__main__":
    main()