Upload app.py
Browse files
app.py
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
|
3 |
+
import streamlit as st
|
4 |
+
from PIL import Image
|
5 |
+
import torch
|
6 |
+
from torchvision import transforms
|
7 |
+
from transformers import AutoModelForImageClassification
|
8 |
+
import pandas as pd
|
9 |
+
|
10 |
+
# Load your model
|
11 |
+
@st.cache_data
|
12 |
+
def load_dataset():
|
13 |
+
dataset_path = "./Data_Entry_2017_v2020.csv" # Replace with your dataset path
|
14 |
+
return pd.read_csv(dataset_path)
|
15 |
+
|
16 |
+
data = load_dataset()
|
17 |
+
|
18 |
+
@st.cache_resource
|
19 |
+
def load_model():
|
20 |
+
# Define the model architecture
|
21 |
+
model = AutoModelForImageClassification.from_pretrained("google/vit-base-patch16-224-in21k", num_labels=15)
|
22 |
+
# Load the saved state dictionary
|
23 |
+
state_dict = torch.load("best_model_new_retrain.pth", map_location=torch.device('cpu'))
|
24 |
+
model.load_state_dict(state_dict)
|
25 |
+
model.eval()
|
26 |
+
return model
|
27 |
+
|
28 |
+
model = load_model()
|
29 |
+
|
30 |
+
# Define image transformation
|
31 |
+
transform = transforms.Compose([
|
32 |
+
transforms.Resize((224, 224)), # Adjust based on your model's requirements
|
33 |
+
transforms.ToTensor(),
|
34 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # ImageNet stats
|
35 |
+
])
|
36 |
+
|
37 |
+
# Function to make predictions
|
38 |
+
def predict_image(image):
|
39 |
+
image = transform(image).unsqueeze(0) # Add batch dimension
|
40 |
+
with torch.no_grad():
|
41 |
+
outputs = model(image).logits
|
42 |
+
probabilities = torch.sigmoid(outputs)
|
43 |
+
return probabilities
|
44 |
+
|
45 |
+
# Streamlit App
|
46 |
+
st.title("Chest Xray Disease Prediction App")
|
47 |
+
st.write("Upload single or multiple images to get predictions.")
|
48 |
+
|
49 |
+
# File uploader for single or bulk images
|
50 |
+
uploaded_files = st.file_uploader("Upload Image(s)", type=["jpg", "png", "jpeg"], accept_multiple_files=True)
|
51 |
+
|
52 |
+
# Process each uploaded file
|
53 |
+
if uploaded_files:
|
54 |
+
for uploaded_file in uploaded_files:
|
55 |
+
# Load and display the image
|
56 |
+
image = Image.open(uploaded_file).convert("RGB")
|
57 |
+
st.image(image, caption=f"Uploaded Image: {uploaded_file.name}", use_column_width=True)
|
58 |
+
|
59 |
+
# Search for the filename in the dataset
|
60 |
+
uploaded_filename = uploaded_file.name
|
61 |
+
matching_row = data[data['Image Index'] == uploaded_filename]
|
62 |
+
truth = matching_row.iloc[0]['Finding Labels'] if not matching_row.empty else "No matching label found"
|
63 |
+
|
64 |
+
st.write(f"**Truth (Ground Truth Labels):** {truth}")
|
65 |
+
|
66 |
+
# Get predictions
|
67 |
+
probabilities = predict_image(image)
|
68 |
+
|
69 |
+
# Create a DataFrame to display probabilities
|
70 |
+
label_columns = [
|
71 |
+
'No Finding', 'Infiltration', 'Effusion', 'Atelectasis', 'Nodule',
|
72 |
+
'Mass', 'Pneumothorax', 'Consolidation', 'Pleural_Thickening',
|
73 |
+
'Cardiomegaly', 'Emphysema', 'Edema', 'Fibrosis', 'Pneumonia', 'Hernia'
|
74 |
+
]
|
75 |
+
prediction_df = pd.DataFrame({
|
76 |
+
"Class": label_columns,
|
77 |
+
"Probability": probabilities.squeeze().tolist()
|
78 |
+
})
|
79 |
+
|
80 |
+
# Highlight the highest probabilities (you can customize the threshold)
|
81 |
+
prediction_df['Highlight'] = prediction_df['Probability'] > 0.5
|
82 |
+
|
83 |
+
# Display predictions
|
84 |
+
st.write("**Prediction (Model Probabilities):**")
|
85 |
+
st.dataframe(
|
86 |
+
prediction_df.style.format({"Probability": "{:.2f}"}).applymap(
|
87 |
+
lambda val: 'background-color: yellow;' if val else '', subset=['Highlight']
|
88 |
+
)
|
89 |
+
)
|
90 |
+
|