"""Streamlit app that generates a caption for an uploaded image.

Uses the ``nlpconnect/vit-gpt2-image-captioning`` vision-encoder /
text-decoder model from the Hugging Face Hub.
"""

import streamlit as st
import torch
from PIL import Image, UnidentifiedImageError
from transformers import AutoTokenizer, VisionEncoderDecoderModel, ViTImageProcessor

MODEL_NAME = "nlpconnect/vit-gpt2-image-captioning"

# Set generation parameters
max_length = 16
num_beams = 4
gen_kwargs = {"max_length": max_length, "num_beams": num_beams}


@st.cache_resource
def load_captioning_components(model_name=MODEL_NAME):
    """Load the captioning model, image processor and tokenizer once.

    Streamlit re-executes the whole script on every widget interaction, so
    loading the weights at plain module level would reload them on each
    rerun. ``st.cache_resource`` keeps a single shared copy instead.

    Args:
        model_name: Hugging Face Hub id (or local path) of the checkpoint.

    Returns:
        A ``(model, feature_extractor, tokenizer, device)`` tuple; the model
        has already been moved to ``device``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = VisionEncoderDecoderModel.from_pretrained(model_name)
    model.to(device)
    feature_extractor = ViTImageProcessor.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return model, feature_extractor, tokenizer, device


# Module-level names are kept so code that imports them from this module
# keeps working; the cached loader makes this cheap on Streamlit reruns.
model, feature_extractor, tokenizer, device = load_captioning_components()


def predict_caption(image):
    """Return a single caption string for a PIL image.

    Args:
        image: A ``PIL.Image.Image`` in any mode; non-RGB images (RGBA,
            grayscale, palette, ...) are converted to RGB first.

    Returns:
        The decoded caption with special tokens and surrounding whitespace
        removed.
    """
    if image.mode != "RGB":
        image = image.convert(mode="RGB")

    # No manual resize: ViTImageProcessor resizes and normalizes to the size
    # in its own configuration by default. A hard-coded resize beforehand
    # only adds a second lossy resampling step and may not match that size.
    pixel_values = feature_extractor(images=[image], return_tensors="pt").pixel_values
    pixel_values = pixel_values.to(device)

    output_ids = model.generate(pixel_values, **gen_kwargs)
    preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    return preds[0].strip()


def main():
    """Render the upload widget and display a caption for the chosen image."""
    st.title("Image Caption Generator")
    st.write("Upload an image, and the model will describe what it sees.")

    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
    if uploaded_file is None:
        return

    try:
        image = Image.open(uploaded_file)
    except UnidentifiedImageError:
        # The `type` filter does not guarantee the bytes are a decodable
        # image; report it instead of crashing the page.
        st.error("The uploaded file could not be read as an image.")
        return

    st.image(image, caption='Uploaded Image', use_column_width=True)

    # Beam-search generation can take a while, especially on CPU.
    with st.spinner("Generating caption..."):
        caption = predict_caption(image)
    st.write("Caption:", caption)


# Run the application
if __name__ == "__main__":
    main()