File size: 2,052 Bytes
d3e0996 4f343b0 319d3d6 d3e0996 319d3d6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 |
import streamlit as st
from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
import torch
from PIL import Image
# Hugging Face checkpoint that provides the vision encoder-decoder model,
# its image processor, and its tokenizer.
MODEL_NAME = "nlpconnect/vit-gpt2-image-captioning"


@st.cache_resource
def load_captioning_components(model_name=MODEL_NAME):
    """Load the captioning model, image processor and tokenizer once.

    Streamlit re-runs this whole script on every user interaction; loading
    the checkpoint at module level would therefore reload it on each rerun.
    ``st.cache_resource`` keeps a single shared copy per ``model_name`` across
    reruns and sessions.

    Args:
        model_name: Hugging Face model id (or local path) of a
            VisionEncoderDecoder captioning checkpoint.

    Returns:
        Tuple ``(model, feature_extractor, tokenizer, device)``; the model has
        already been moved to ``device`` (CUDA when available, else CPU).
    """
    loaded_model = VisionEncoderDecoderModel.from_pretrained(model_name)
    processor = ViTImageProcessor.from_pretrained(model_name)
    loaded_tokenizer = AutoTokenizer.from_pretrained(model_name)
    target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    loaded_model.to(target_device)
    return loaded_model, processor, loaded_tokenizer, target_device


# Module-level handles read by predict_caption().
model, feature_extractor, tokenizer, device = load_captioning_components()

# Generation parameters forwarded to model.generate(): max_length caps the
# generated token sequence and num_beams sets the beam-search width.
max_length = 16
num_beams = 4
gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
# Define the prediction function
def predict_caption(image):
    """Generate a caption for a single PIL image.

    Uses the module-level ``feature_extractor``, ``model``, ``tokenizer``,
    ``device`` and ``gen_kwargs``.

    Args:
        image: A ``PIL.Image.Image``. Images in any mode other than RGB
            (e.g. RGBA, grayscale, palette) are converted to RGB before
            preprocessing.

    Returns:
        The decoded caption with special tokens removed and surrounding
        whitespace stripped.
    """
    if image.mode != "RGB":
        image = image.convert(mode="RGB")
    # No manual resize: the image processor resizes and normalizes to the
    # size stored in the checkpoint's preprocessing config. Resizing here
    # first would only resample the image twice. The processor takes no
    # ``padding`` argument, so none is passed.
    pixel_values = feature_extractor(images=[image], return_tensors="pt").pixel_values
    pixel_values = pixel_values.to(device)
    output_ids = model.generate(pixel_values, **gen_kwargs)
    preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    # A batch of one image yields exactly one decoded caption.
    return preds[0].strip()
# Main function for Streamlit app
def main():
    """Render the page: upload an image, display it, and show its caption.

    Files that cannot be decoded as images produce an on-page error message
    instead of an uncaught exception.
    """
    from PIL import ImageOps

    st.title("Image Caption Generator")
    st.write("Upload an image, and the model will describe what it sees.")
    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
    if uploaded_file is None:
        return  # nothing uploaded yet

    try:
        image = Image.open(uploaded_file)
        # Camera photos often record rotation only in EXIF metadata; apply it
        # so the displayed and captioned image is upright. This also forces
        # the pixel data to be decoded, surfacing truncated files here.
        image = ImageOps.exif_transpose(image)
    except OSError as err:  # PIL's UnidentifiedImageError subclasses OSError
        st.error(f"Could not read the uploaded file as an image: {err}")
        return

    st.image(image, caption='Uploaded Image', use_column_width=True)
    with st.spinner("Generating caption..."):
        caption = predict_caption(image)
    st.write("Caption:", caption)


# Run the application
if __name__ == "__main__":
    main()
|