kavithapadala commited on
Commit
463a6b5
·
verified ·
1 Parent(s): 9cb8b00

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +90 -0
app.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ import streamlit as st
4
+ from PIL import Image
5
+ import torch
6
+ from torchvision import transforms
7
+ from transformers import AutoModelForImageClassification
8
+ import pandas as pd
9
+
10
+ # Load your model
11
+ @st.cache_data
12
+ def load_dataset():
13
+ dataset_path = "./Data_Entry_2017_v2020.csv" # Replace with your dataset path
14
+ return pd.read_csv(dataset_path)
15
+
16
+ data = load_dataset()
17
+
18
+ @st.cache_resource
19
+ def load_model():
20
+ # Define the model architecture
21
+ model = AutoModelForImageClassification.from_pretrained("google/vit-base-patch16-224-in21k", num_labels=15)
22
+ # Load the saved state dictionary
23
+ state_dict = torch.load("best_model_new_retrain.pth", map_location=torch.device('cpu'))
24
+ model.load_state_dict(state_dict)
25
+ model.eval()
26
+ return model
27
+
28
+ model = load_model()
29
+
30
+ # Define image transformation
31
+ transform = transforms.Compose([
32
+ transforms.Resize((224, 224)), # Adjust based on your model's requirements
33
+ transforms.ToTensor(),
34
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # ImageNet stats
35
+ ])
36
+
37
+ # Function to make predictions
38
+ def predict_image(image):
39
+ image = transform(image).unsqueeze(0) # Add batch dimension
40
+ with torch.no_grad():
41
+ outputs = model(image).logits
42
+ probabilities = torch.sigmoid(outputs)
43
+ return probabilities
44
+
45
+ # Streamlit App
46
+ st.title("Chest Xray Disease Prediction App")
47
+ st.write("Upload single or multiple images to get predictions.")
48
+
49
+ # File uploader for single or bulk images
50
+ uploaded_files = st.file_uploader("Upload Image(s)", type=["jpg", "png", "jpeg"], accept_multiple_files=True)
51
+
52
+ # Process each uploaded file
53
+ if uploaded_files:
54
+ for uploaded_file in uploaded_files:
55
+ # Load and display the image
56
+ image = Image.open(uploaded_file).convert("RGB")
57
+ st.image(image, caption=f"Uploaded Image: {uploaded_file.name}", use_column_width=True)
58
+
59
+ # Search for the filename in the dataset
60
+ uploaded_filename = uploaded_file.name
61
+ matching_row = data[data['Image Index'] == uploaded_filename]
62
+ truth = matching_row.iloc[0]['Finding Labels'] if not matching_row.empty else "No matching label found"
63
+
64
+ st.write(f"**Truth (Ground Truth Labels):** {truth}")
65
+
66
+ # Get predictions
67
+ probabilities = predict_image(image)
68
+
69
+ # Create a DataFrame to display probabilities
70
+ label_columns = [
71
+ 'No Finding', 'Infiltration', 'Effusion', 'Atelectasis', 'Nodule',
72
+ 'Mass', 'Pneumothorax', 'Consolidation', 'Pleural_Thickening',
73
+ 'Cardiomegaly', 'Emphysema', 'Edema', 'Fibrosis', 'Pneumonia', 'Hernia'
74
+ ]
75
+ prediction_df = pd.DataFrame({
76
+ "Class": label_columns,
77
+ "Probability": probabilities.squeeze().tolist()
78
+ })
79
+
80
+ # Highlight the highest probabilities (you can customize the threshold)
81
+ prediction_df['Highlight'] = prediction_df['Probability'] > 0.5
82
+
83
+ # Display predictions
84
+ st.write("**Prediction (Model Probabilities):**")
85
+ st.dataframe(
86
+ prediction_df.style.format({"Probability": "{:.2f}"}).applymap(
87
+ lambda val: 'background-color: yellow;' if val else '', subset=['Highlight']
88
+ )
89
+ )
90
+