"""Streamlit app that generates natural-language captions for uploaded images
using the nlpconnect/vit-gpt2-image-captioning model from the Hugging Face Hub."""

import streamlit as st
import torch
from PIL import Image
from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer

# Load the model, processor, and tokenizer once and cache them across Streamlit
# reruns, so user interactions don't reload the weights; the first run downloads
# them from the Hugging Face Hub.
@st.cache_resource
def load_model():
    model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
    feature_extractor = ViTImageProcessor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
    tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
    # Run inference on the GPU when one is available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    return model, feature_extractor, tokenizer, device


model, feature_extractor, tokenizer, device = load_model()

max_length = 16
num_beams = 4
gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
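# Note: beam search (num_beams=4) typically yields more fluent captions than
# greedy decoding at a small latency cost, and 16 tokens is enough for the
# short one-sentence captions this model produces.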


def predict_caption(image):
    """Generate a caption for a PIL image."""
    # The model expects 3-channel RGB input; convert grayscale, palette, RGBA, etc.
    if image.mode != "RGB":
        image = image.convert(mode="RGB")

    # The processor resizes and normalizes the image to the model's expected
    # input size itself, so the manual Image.resize call is unnecessary, and
    # `padding` is dropped because ViT inputs are fixed-size.
    pixel_values = feature_extractor(images=[image], return_tensors="pt").pixel_values.to(device)

    # Beam-search decode, then strip special tokens from the generated ids.
    output_ids = model.generate(pixel_values, **gen_kwargs)
    preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    return preds[0].strip()
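# Quick sanity check outside the app (the image path is a placeholder):
#   print(predict_caption(Image.open("photo.jpg")))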


def main():
    st.title("Image Caption Generator")
    st.write("Upload an image, and the model will describe what it sees.")

    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

    if uploaded_file is not None:
        image = Image.open(uploaded_file)
        st.image(image, caption="Uploaded Image", use_column_width=True)

        # Captioning can take a few seconds, especially on CPU.
        with st.spinner("Generating caption..."):
            caption = predict_caption(image)
        st.write("Caption:", caption)


if __name__ == "__main__":
    main()
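# Launch with Streamlit's CLI, assuming this script is saved as app.py:
#   streamlit run app.py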