# File-viewer header captured with this file:
# author Itanutiwari527 · commit 4eea648 "Update app.py" (verified) · 3.7 kB
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F
from PIL import Image
import streamlit as st
from pathlib import Path
import requests
# Load pre-trained ResNet model (once per server process, not once per rerun).
@st.cache_resource
def _load_model():
    """Return an eval-mode ResNet-50 that accepts RGB batches with values in [0, 1].

    The ImageNet weights expect channel-wise mean/std normalized input, while
    the rest of this app (perturbation clamping, PIL conversion) works in
    [0, 1] pixel space. Normalization is therefore folded into the model as
    its first layer, so callers keep passing un-normalized tensors.
    """
    weights = torchvision.models.ResNet50_Weights.IMAGENET1K_V1  # same weights as pretrained=True
    net = torch.nn.Sequential(
        # Standard ImageNet channel statistics used by torchvision's weights.
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        torchvision.models.resnet50(weights=weights),
    )
    net.eval()
    # Only gradients w.r.t. the input image are ever needed; freezing the
    # parameters avoids computing/storing weight gradients in the shared model.
    net.requires_grad_(False)
    return net


model = _load_model()
# Load ImageNet class labels (one name per line; line i is class index i).
LABELS_URL = "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"


@st.cache_data
def _fetch_labels(url):
    """Download the class-name list from ``url``.

    Raises:
        requests.HTTPError: if the server answers with an error status, so a
            404 page is never mistaken for the label list.
        requests.RequestException: on network failure or after the timeout.
    """
    response = requests.get(url, timeout=10)
    response.raise_for_status()
    # splitlines() also copes with CRLF line endings.
    return response.text.strip().splitlines()


labels = _fetch_labels(LABELS_URL)
# Image transform: force every input to the fixed 224x224 resolution, then
# turn the PIL image into a float tensor laid out (C, H, W) with values in [0, 1].
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
    ]
)
# Load predefined signs from the 'signs' folder. Keys are display names derived
# from the file stem (underscores become spaces); values are the file paths.
signs_dir = Path("signs")
# Only regular files with these extensions are offered, so stray entries
# (sub-directories, .DS_Store, READMEs) can never be handed to Image.open.
_IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg"}
presets = {
    f.stem.replace("_", " "): f
    # sorted() gives a stable dropdown order; glob order is filesystem-dependent.
    for f in sorted(signs_dir.glob("*"))
    if f.is_file() and f.suffix.lower() in _IMAGE_SUFFIXES
}
def preprocess_image(image):
    """Turn a PIL image into a model-ready batch containing one image.

    The image is converted to 3-channel RGB, run through the module-level
    ``transform``, and given a leading batch dimension of size 1.
    """
    rgb_image = image.convert("RGB")
    # Indexing with None inserts the batch axis at position 0.
    return transform(rgb_image)[None]
def predict_class(tensor):
    """Return the name of the class the model scores highest for ``tensor``.

    The forward pass runs without gradient tracking. The winning index is
    read with ``.item()``, so the batch must contain exactly one image.
    """
    with torch.no_grad():
        logits = model(tensor)
    best_index = logits.max(dim=1).indices
    return labels[best_index.item()]
def generate_adversarial_example(tensor, epsilon, target_class=400):
    """Craft a one-step FGSM adversarial version of ``tensor``.

    Args:
        tensor: Batched image tensor with values in [0, 1], e.g. the output of
            ``preprocess_image``. It is not modified.
        epsilon: L-infinity step size; every pixel moves by at most this much.
        target_class: Class index to push the prediction *towards* (targeted
            FGSM: step against the gradient of this class's loss). Pass
            ``None`` for an untargeted attack that instead steps *away* from
            the model's current prediction.

    Returns:
        A detached tensor with the same shape as ``tensor``, clamped to [0, 1].
    """
    # Private leaf copy: never mutate the caller's tensor or keep its history.
    x = tensor.detach().clone().requires_grad_(True)
    outputs = model(x)
    if target_class is None:
        target = outputs.argmax(dim=1).detach()
        direction = 1.0  # ascend the loss of the current prediction
    else:
        # One target per batch element, on the input's device.
        target = torch.full((x.shape[0],), target_class, dtype=torch.long, device=x.device)
        direction = -1.0  # descend the loss of the requested target class
    loss = F.cross_entropy(outputs, target)
    # Only the input gradient is needed; autograd.grad does not accumulate
    # .grad on the model's parameters the way loss.backward() does.
    (grad,) = torch.autograd.grad(loss, x)
    with torch.no_grad():
        perturbed = x + direction * epsilon * grad.sign()
        return torch.clamp(perturbed, 0, 1)
def _load_image(uploaded_file, selected_preset):
    """Open the user's chosen image, or return None when nothing is chosen.

    An uploaded file takes precedence over a preset. Both sources go through
    the same path (no ad-hoc resizing), so the model sees identical
    preprocessing either way. The image is decoded eagerly, so unreadable
    files fail here and preset file handles are not left open.

    Raises:
        OSError: if the data cannot be identified or decoded as an image
            (PIL's ``UnidentifiedImageError`` is an ``OSError`` subclass).
    """
    if uploaded_file:
        source = uploaded_file
    elif selected_preset:
        source = presets[selected_preset]
    else:
        return None
    image = Image.open(source)
    image.load()
    return image


def main():
    """Render the app: pick an image, perturb it, and compare both predictions."""
    st.title("Adversarial Attack on Traffic Signs")
    st.write("Upload a traffic sign image or select a predefined one, then apply perturbation.")
    col1, col2 = st.columns(2)
    with col1:
        uploaded_file = st.file_uploader("Upload a traffic sign image", type=["png", "jpg", "jpeg"])
    with col2:
        selected_preset = st.selectbox("Or choose a predefined sign", [None] + list(presets.keys()))
    # User-defined perturbation strength limit; it caps the slider below.
    max_epsilon = st.number_input("Set Maximum Perturbation Strength (ε)", min_value=0.01, max_value=1.0, value=0.1, step=0.01)
    epsilon = st.slider("Select Perturbation Strength (ε)", 0.0, float(max_epsilon), 0.01, step=0.01)

    try:
        image = _load_image(uploaded_file, selected_preset)
    except OSError as err:
        # The uploader only filters by extension, not by actual file content.
        st.error(f"Could not read the selected image: {err}")
        return
    if image is None:
        st.warning("Please upload image.")
        return

    tensor = preprocess_image(image)
    original_label = predict_class(tensor)
    # Generate perturbed image (from a copy, so `tensor` itself stays untouched).
    perturbed_tensor = generate_adversarial_example(tensor.clone(), epsilon)
    adversarial_label = predict_class(perturbed_tensor)
    # `transform` already produced a 224x224 tensor, so no resize is needed here;
    # detach() keeps the PIL conversion safe even if the tensor tracks gradients.
    perturbed_image = transforms.ToPILImage()(perturbed_tensor.detach().squeeze(0))

    # Create two columns for displaying original vs adversarial
    col1, col2 = st.columns(2)
    with col1:
        st.markdown("### Original Prediction")
        st.success(original_label)
        # Display the original at the same 224x224 size as the perturbed image.
        st.image(image.resize((224, 224)), caption="Original Image", use_container_width=True)
    with col2:
        st.markdown("### Adversarial Prediction")
        st.error(adversarial_label)
        st.image(perturbed_image, caption="Perturbed Image", use_container_width=True)


if __name__ == "__main__":
    main()