|
import torch |
|
import torchvision |
|
import torchvision.transforms as transforms |
|
import torch.nn.functional as F |
|
from PIL import Image |
|
import streamlit as st |
|
from pathlib import Path |
|
import requests |
|
|
|
|
|
@st.cache_resource
def _load_model():
    """Build the ImageNet-pretrained ResNet-50 classifier once and reuse it.

    Streamlit re-executes this whole script on every widget interaction;
    ``st.cache_resource`` keeps a single model instance alive across reruns
    instead of rebuilding it and reloading its weights each time.

    ``ResNet50_Weights.IMAGENET1K_V1`` is the checkpoint the deprecated
    ``pretrained=True`` flag used to select, so predictions are unchanged.
    """
    net = torchvision.models.resnet50(
        weights=torchvision.models.ResNet50_Weights.IMAGENET1K_V1
    )
    # Inference mode for batch-norm/dropout layers; the model is never trained here.
    net.eval()
    return net


model = _load_model()
|
|
|
|
|
LABELS_URL = "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"


@st.cache_data
def _load_labels(url=LABELS_URL):
    """Download the class-name list, one label per line, indexed by class id.

    Cached with ``st.cache_data`` so the HTTP request is made once instead of
    on every Streamlit rerun. A failed request is not cached and is retried
    on the next rerun.

    Raises:
        requests.RequestException: on network failure, timeout, or a non-2xx
            response (previously an error page would silently become the
            label list).
    """
    response = requests.get(url, timeout=10)
    response.raise_for_status()
    # splitlines() also strips "\r" if the file is served with CRLF endings.
    return response.text.strip().splitlines()


labels = _load_labels()
|
|
|
|
|
# Preprocessing pipeline applied to every input image:
#   1. resize to a fixed 224x224 resolution (aspect ratio is not preserved);
#   2. convert to a float tensor with channel-first layout and values in [0, 1].
# NOTE(review): there is no transforms.Normalize step, while torchvision's
# ImageNet checkpoints are documented as expecting mean/std-normalised input.
# Adding it here alone would break the [0, 1] clamp and ToPILImage used
# downstream, so any such change must be made together with those.
_transform_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
]
transform = transforms.Compose(_transform_steps)
|
|
|
|
|
signs_dir = Path("signs")

# Same file types the upload widget accepts; anything else (sub-directories,
# .DS_Store, README files, ...) would make Image.open fail when selected.
_PRESET_SUFFIXES = {".png", ".jpg", ".jpeg"}

# Display name (file stem with underscores shown as spaces) -> image path.
# Sorted so the dropdown order is stable; glob() order is filesystem-dependent.
# A missing "signs" directory simply yields no presets.
presets = {
    f.stem.replace("_", " "): f
    for f in sorted(signs_dir.glob("*"))
    if f.is_file() and f.suffix.lower() in _PRESET_SUFFIXES
}
|
|
|
def preprocess_image(image):
    """Turn a PIL image into a one-image batch tensor for the model.

    The image is first forced to 3-channel RGB, whatever mode the source
    file had. It is then run through the module-level ``transform``, and a
    leading batch dimension of size 1 is added.
    """
    rgb_image = image.convert("RGB")
    return torch.unsqueeze(transform(rgb_image), dim=0)
|
|
|
def predict_class(tensor):
    """Return the label string of the model's top-scoring class for ``tensor``.

    Gradients are disabled for the forward pass. Because ``.item()`` needs a
    single-element result, ``tensor`` must hold exactly one image (batch of 1).
    """
    with torch.no_grad():
        logits = model(tensor)
        # argmax over the class dimension; like max(1), ties go to the first index.
        best_index = logits.argmax(dim=1).item()
    return labels[best_index]
|
|
|
def generate_adversarial_example(tensor, epsilon, target_class=400):
    """Craft a targeted FGSM adversarial example from ``tensor``.

    A single Fast Gradient Sign Method step moves every pixel by ``epsilon``
    in the direction that *decreases* the cross-entropy loss for
    ``target_class``. This steers the model's prediction towards that class.
    The result is clamped back into the ``[0, 1]`` pixel range.

    Args:
        tensor: Input batch of images with values in ``[0, 1]``. It is not
            modified.
        epsilon: Step size; before clamping, no pixel changes by more than
            ``epsilon``.
        target_class: Class index the attack steers towards, applied to
            every image in the batch.

    Returns:
        A detached tensor with the same shape as ``tensor``.
    """
    # Work on a detached copy so the caller's tensor is never mutated. This
    # also works when ``tensor`` is a non-leaf node of some autograd graph
    # (setting ``requires_grad`` on one raises).
    x = tensor.detach().clone().requires_grad_(True)
    outputs = model(x)
    # One target per batch item, on the same device as the logits.
    targets = torch.full(
        (outputs.shape[0],), target_class, dtype=torch.long, device=outputs.device
    )
    loss = F.cross_entropy(outputs, targets)
    # autograd.grad computes only d(loss)/dx. Unlike loss.backward(), it does
    # not accumulate gradients into the model parameters' .grad buffers on
    # every call.
    (grad,) = torch.autograd.grad(loss, x)
    # Targeted attack: step *against* the gradient so target_class becomes
    # more likely. Adding it would instead push the prediction away from it.
    perturbed = x.detach() - epsilon * grad.sign()
    return torch.clamp(perturbed, 0, 1)
|
|
|
def _load_image(uploaded_file, selected_preset):
    """Open the image chosen in the UI, or return ``None`` if none is chosen.

    An uploaded file takes precedence over a preset. Both sources are resized
    to 224x224, so the original image shown next to the perturbed one has the
    same size regardless of where it came from.

    Raises:
        OSError: if the file is missing or Pillow cannot decode it (Pillow's
            ``UnidentifiedImageError`` is an ``OSError`` subclass).
    """
    if uploaded_file:
        source = uploaded_file
    elif selected_preset:
        source = presets[selected_preset]
    else:
        return None
    with Image.open(source) as img:
        # resize() returns a new, fully loaded image, so the opened source can
        # be released when the context exits.
        return img.resize((224, 224))


def main():
    """Render the adversarial-attack demo page.

    The user uploads an image or picks one of the bundled ``presets``, then
    chooses a perturbation strength ``epsilon``. The page shows the model's
    prediction for the original image next to its prediction for the image
    perturbed by ``generate_adversarial_example``.
    """
    st.title("Adversarial Attack on Traffic Signs")
    st.write("Upload a traffic sign image or select a predefined one, then apply perturbation.")

    upload_col, preset_col = st.columns(2)
    with upload_col:
        uploaded_file = st.file_uploader("Upload a traffic sign image", type=["png", "jpg", "jpeg"])
    with preset_col:
        selected_preset = st.selectbox("Or choose a predefined sign", [None] + list(presets.keys()))

    # The slider's upper bound is itself user-configurable: small epsilons
    # stay easy to fine-tune while larger ones remain reachable.
    max_epsilon = st.number_input("Set Maximum Perturbation Strength (ε)", min_value=0.01, max_value=1.0, value=0.1, step=0.01)
    epsilon = st.slider("Select Perturbation Strength (ε)", 0.0, float(max_epsilon), 0.01, step=0.01)

    try:
        image = _load_image(uploaded_file, selected_preset)
    except OSError as err:
        # A corrupt or non-image file would otherwise crash the whole page
        # with a traceback; report it and stop this rerun instead.
        st.error(f"Could not open the selected image: {err}")
        return

    if image is None:
        st.warning("Please upload image.")
        return

    tensor = preprocess_image(image)
    original_label = predict_class(tensor)

    perturbed_tensor = generate_adversarial_example(tensor.clone(), epsilon)
    # Drop the batch dimension (and any autograd history) before converting
    # back to a displayable PIL image.
    perturbed_image = transforms.ToPILImage()(perturbed_tensor.detach().squeeze())
    adversarial_label = predict_class(perturbed_tensor)

    original_col, adversarial_col = st.columns(2)

    with original_col:
        st.markdown("### Original Prediction")
        st.success(original_label)
        st.image(image, caption="Original Image", use_container_width=True)

    with adversarial_col:
        st.markdown("### Adversarial Prediction")
        st.error(adversarial_label)
        perturbed_image_resized = perturbed_image.resize((224, 224))
        st.image(perturbed_image_resized, caption="Perturbed Image", use_container_width=True)
|
|
|
# Render the page only when this file is executed as the main script (which
# is how `streamlit run` executes it); importing the module does not render
# anything.
if __name__ == "__main__":
    main()
|
|