File size: 5,105 Bytes
611e872
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import streamlit as st
import numpy as np
from PIL import Image
from tensorflow.keras.models import load_model
import requests
from io import BytesIO

# Load the pre-trained chest X-ray classification model
MODEL_PATH = 'chest_xray_model.h5'


def _load_classifier(path=MODEL_PATH):
    """Load and return the Keras model stored at ``path``."""
    return load_model(path)


# Streamlit re-executes this script on every widget interaction, so an
# uncached load would re-read the model file each time. `st.cache_resource`
# keeps a single loaded model alive across reruns; Streamlit releases that
# do not provide it fall back to the plain (uncached) loader.
_cache_resource = getattr(st, "cache_resource", None)
if _cache_resource is not None:
    _load_classifier = _cache_resource(_load_classifier)

try:
    model = _load_classifier()
except Exception as e:
    # Show the underlying reason (missing file, incompatible format, ...)
    # and halt the script: nothing below can work without a model.
    st.error(f"Error loading model: {e}")
    st.stop()

# Define the class names for prediction output
# (index i labels element i of the model's score vector).
class_names = ['Normal', 'Pneumonia']

# Prediction class for encapsulating the prediction logic
class Prediction:
    """Wrap a trained model and turn raw image arrays into labelled predictions."""

    def __init__(self, model):
        """Keep a reference to the model used for inference.

        Args:
            model: Object whose ``predict(batch)`` method is called by
                :meth:`classify_image`.
        """
        self.model = model

    def classify_image(self, image):
        """Classify a single image and report the winning class.

        The array is converted to an RGB image, resized to 512x512, scaled
        to the ``[0, 1]`` range as ``float32`` and passed to the model as a
        batch of one.

        Args:
            image: Array accepted by ``PIL.Image.fromarray``.

        Returns:
            A tuple ``(predicted_class, predicted_confidence, predictions)``:
            the label from ``class_names`` with the highest score, that score
            as a percentage, and the model's score vector for the image.
            On any failure an error is shown in the app and
            ``(None, None, None)`` is returned instead.
        """
        try:
            image = Image.fromarray(image).convert('RGB')
            image = image.resize((512, 512))
            image_array = np.array(image).astype(np.float32) / 255.0
            # The model is fed a batch, so add a leading batch axis.
            image_array = np.expand_dims(image_array, axis=0)

            # Only one image was submitted; keep its scores.
            predictions = self.model.predict(image_array)[0]
            predicted_class_idx = np.argmax(predictions)
            predicted_class = class_names[predicted_class_idx]
            predicted_confidence = predictions[predicted_class_idx] * 100

            return predicted_class, predicted_confidence, predictions
        except Exception as e:
            # Broad on purpose: this is the UI boundary, and a failure on one
            # image must not abort the page. Surface the cause so the failure
            # can be diagnosed.
            st.error(f"Error during classification: {e}")
            return None, None, None

# Text shown at the top of the page.
APP_TITLE = "📊 Chest X-Ray Classification"
APP_INTRO = """
    Upload one or more chest X-ray images or provide an image URL, and the model will classify each as either **Normal** or **Pneumonia**.
"""

# Classifier wrapper shared by every prediction made during this run.
predictor = Prediction(model)

# Page header: title first, then the short usage blurb.
st.title(APP_TITLE)
st.markdown(APP_INTRO)

# Seconds to wait for a remote image before giving up, so an unresponsive
# host cannot block the script indefinitely.
URL_FETCH_TIMEOUT = 10


def _start_new_session():
    """Advance the session counter so the keyed input widgets start empty.

    Used as the ``on_click`` callback of the "Start New Session" button.
    Streamlit runs widget callbacks before re-executing the script, so this
    takes effect even though the button itself is only rendered during the
    run in which "Submit" was clicked.
    """
    st.session_state["session_id"] = st.session_state.get("session_id", 0) + 1


# The counter is embedded in the keys of the input widgets below; when it
# changes, Streamlit treats them as new widgets with no previous value.
session_id = st.session_state.get("session_id", 0)

# Input option selection
input_option = st.radio("Choose how to upload the image(s):", ("Upload Image(s)", "Image URL"))

# Initialize images list: (pixel array, display name) pairs rebuilt each run
images = []

# Patient name input
patient_name = st.text_input("### Patient Name", key=f"patient_name_{session_id}")

if input_option == "Upload Image(s)":
    uploaded_images = st.file_uploader(
        "### Step 1: Upload Your Chest X-Ray Image(s)",
        type=["jpg", "jpeg", "png"],
        accept_multiple_files=True,
        key=f"uploaded_images_{session_id}",
    )
    if uploaded_images:
        for uploaded_image in uploaded_images:
            try:
                image = np.array(Image.open(uploaded_image))
            except Exception as e:
                # Name the offending file: several may be uploaded at once.
                st.error(f"Error loading image {uploaded_image.name}: {e}")
            else:
                images.append((image, uploaded_image.name))

elif input_option == "Image URL":
    image_url = st.text_input("### Step 1: Enter the Image URL", key=f"image_url_{session_id}")
    if image_url:
        # NOTE(review): the URL is user-supplied and fetched from the server
        # running this app; consider restricting allowed hosts if the app is
        # exposed beyond trusted users.
        try:
            response = requests.get(image_url, timeout=URL_FETCH_TIMEOUT)
            if response.status_code == 200:
                images.append((np.array(Image.open(BytesIO(response.content))), image_url))
                st.markdown(f"[Image URL]({image_url})")
            else:
                st.error(
                    "Error fetching image from URL: Unable to retrieve the image "
                    f"(HTTP {response.status_code})."
                )
        except Exception as e:
            # Covers invalid URLs, network failures/timeouts and payloads
            # that cannot be decoded as an image.
            st.error(f"Error fetching image from URL: {e}")

# Store classification results
results = []

if images:
    submit_button = st.button("Submit", key="submit")
    if submit_button:
        st.write("### Step 2: Review the Uploaded Image(s) and Results")

        for idx, (image, image_name) in enumerate(images):
            patient_display_name = f"[{patient_name}](#{image_name})"
            st.write(f"#### Patient: {patient_display_name}")

            # Image on the left (wider), prediction summary on the right.
            col1, col2 = st.columns([2, 1])
            with col1:
                st.image(image, caption=image_name, use_column_width=True, clamp=True)

            with col2:
                st.subheader("Prediction Results")
                with st.spinner("Processing..."):
                    predicted_class, predicted_confidence, predictions = predictor.classify_image(image)

                # classify_image returns (None, None, None) on failure and
                # has already reported the error in that case.
                if predicted_class is not None:
                    st.markdown(f"""
                    <div style="border: 2px solid #2196F3; padding: 10px; border-radius: 5px;">
                        <p style="font-size: 16px; font-weight: bold;">Predicted Class: {predicted_class}</p>
                        <p style="font-size: 16px;">Confidence: {predicted_confidence:.2f}%</p>
                        <p style="font-size: 16px; font-weight: bold;">Class Confidence Levels:</p>
                        <ul style="list-style-type: none; padding: 0;">
                            <li style="color: #4CAF50;">Normal: {predictions[0] * 100:.1f}%</li>
                            <li style="color: #F44336;">Pneumonia: {predictions[1] * 100:.1f}%</li>
                        </ul>
                    </div>
                    """, unsafe_allow_html=True)

                    results.append((idx + 1, patient_name, predicted_class, predicted_confidence, predictions))

                    st.markdown("---")

        # Button to manually start a new session.
        # The reset lives in an on_click callback rather than in an
        # `if st.button(...)` body: clicking this button triggers a rerun in
        # which `submit_button` is False, so code nested here would never
        # execute. The click itself causes the rerun; no explicit rerun call
        # is needed.
        st.button("Start New Session", on_click=_start_new_session)

# Static footer describing what the model does and how to read its output.
FOOTER_NOTES = """
    - This model is trained to differentiate between normal and pneumonia-affected chest X-rays.
    - Confidence levels are displayed as a percentage for each class.
"""

st.write("### Additional Information")
st.markdown(FOOTER_NOTES)