Files changed (1) hide show
  1. app.py +122 -100
app.py CHANGED
@@ -1,100 +1,122 @@
1
- import streamlit as st
2
-
3
- import tensorflow as tf
4
- from PIL import Image
5
- import numpy as np
6
- import cv2
7
- from huggingface_hub import from_pretrained_keras
8
-
9
-
10
- try:
11
- model=from_pretrained_keras("SerdarHelli/Segmentation-of-Teeth-in-Panoramic-X-ray-Image-Using-U-Net")
12
- except:
13
- model=tf.keras.models.load_model("dental_xray_seg.h5")
14
- pass
15
-
16
- st.header("Segmentation of Teeth in Panoramic X-ray Image Using UNet")
17
-
18
- examples=["107.png","108.png","109.png"]
19
- link='Check Out Our Github Repo ! [link](https://github.com/SerdarHelli/Segmentation-of-Teeth-in-Panoramic-X-ray-Image-Using-U-Net)'
20
- st.markdown(link,unsafe_allow_html=True)
21
-
22
-
23
- def load_image(image_file):
24
- img = Image.open(image_file)
25
- return img
26
-
27
- def convert_one_channel(img):
28
- #some images have 3 channels , although they are grayscale image
29
- if len(img.shape)>2:
30
- img= cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
31
- return img
32
- else:
33
- return img
34
-
35
- def convert_rgb(img):
36
- #some images have 3 channels , although they are grayscale image
37
- if len(img.shape)==2:
38
- img= cv2.cvtColor(img,cv2.COLOR_GRAY2RGB)
39
- return img
40
- else:
41
- return img
42
-
43
-
44
- st.subheader("Upload Dental Panoramic X-ray Image Image")
45
- image_file = st.file_uploader("Upload Images", type=["png","jpg","jpeg"])
46
-
47
-
48
- col1, col2, col3 = st.columns(3)
49
- with col1:
50
- ex=load_image(examples[0])
51
- st.image(ex,width=200)
52
- if st.button('Example 1'):
53
- image_file=examples[0]
54
-
55
- with col2:
56
- ex1=load_image(examples[1])
57
- st.image(ex1,width=200)
58
- if st.button('Example 2'):
59
- image_file=examples[1]
60
-
61
-
62
- with col3:
63
- ex2=load_image(examples[2])
64
- st.image(ex2,width=200)
65
- if st.button('Example 3'):
66
- image_file=examples[2]
67
-
68
-
69
- if image_file is not None:
70
-
71
- img=load_image(image_file)
72
-
73
- st.text("Making A Prediction ....")
74
- st.image(img,width=850)
75
-
76
- img=np.asarray(img)
77
-
78
- img_cv=convert_one_channel(img)
79
- img_cv=cv2.resize(img_cv,(512,512), interpolation=cv2.INTER_LANCZOS4)
80
- img_cv=np.float32(img_cv/255)
81
-
82
- img_cv=np.reshape(img_cv,(1,512,512,1))
83
- prediction=model.predict(img_cv)
84
- predicted=prediction[0]
85
- predicted = cv2.resize(predicted, (img.shape[1],img.shape[0]), interpolation=cv2.INTER_LANCZOS4)
86
- mask=np.uint8(predicted*255)#
87
- _, mask = cv2.threshold(mask, thresh=0, maxval=255, type=cv2.THRESH_BINARY+cv2.THRESH_OTSU)
88
- kernel =( np.ones((5,5), dtype=np.float32))
89
- mask=cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel,iterations=1 )
90
- mask=cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel,iterations=1 )
91
- cnts,hieararch=cv2.findContours(mask,cv2.RETR_TREE,cv2.CHAIN_APPROX_SIMPLE)
92
- output = cv2.drawContours(convert_rgb(img), cnts, -1, (255, 0, 0) , 3)
93
-
94
-
95
- if output is not None :
96
- st.subheader("Predicted Image")
97
- st.write(output.shape)
98
- st.image(output,width=850)
99
-
100
- st.text("DONE ! ....")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import tensorflow as tf
3
+ from PIL import Image
4
+ import numpy as np
5
+ import cv2
6
+ from huggingface_hub import from_pretrained_keras
7
+
8
+ # Use st.cache_resource to load the model only once, preventing memory errors.
9
+ @st.cache_resource
10
+ def load_keras_model():
11
+ """Load the pre-trained Keras model from Hugging Face Hub and cache it."""
12
+ try:
13
+ # The model will be downloaded from the Hub and cached.
14
+ model = from_pretrained_keras("SerdarHelli/Segmentation-of-Teeth-in-Panoramic-X-ray-Image-Using-U-Net")
15
+ return model
16
+ except Exception as e:
17
+ # If model loading fails, show an error and return None.
18
+ st.error(f"Error loading the model: {e}")
19
+ return None
20
+
21
+ # --- Helper Functions ---
22
+ def load_image(image_file):
23
+ """Loads an image from a file path or uploaded file object."""
24
+ img = Image.open(image_file)
25
+ return img
26
+
27
+ def convert_one_channel(img_array):
28
+ """Ensure the image is single-channel (grayscale)."""
29
+ # If image has 3 channels (like BGR or RGB), convert to grayscale.
30
+ if len(img_array.shape) > 2 and img_array.shape[2] > 1:
31
+ img_array = cv2.cvtColor(img_array, cv2.COLOR_BGR2GRAY)
32
+ return img_array
33
+
34
+ def convert_rgb(img_array):
35
+ """Ensure the image is 3-channel (RGB) for drawing contours."""
36
+ # If image is grayscale, convert to RGB to draw colored contours.
37
+ if len(img_array.shape) == 2:
38
+ img_array = cv2.cvtColor(img_array, cv2.COLOR_GRAY2RGB)
39
+ return img_array
40
+
41
+ # --- Streamlit App Layout ---
42
+ st.header("Segmentation of Teeth in Panoramic X-ray Image Using UNet")
43
+
44
+ link = 'Check Out Our Github Repo! [link](https://github.com/SerdarHelli/Segmentation-of-Teeth-in-Panoramic-X-ray-Image-Using-U-Net)'
45
+ st.markdown(link, unsafe_allow_html=True)
46
+
47
+ # Load the model and stop the app if it fails
48
+ model = load_keras_model()
49
+ if model is None:
50
+ st.warning("Model could not be loaded. The application cannot proceed.")
51
+ st.stop()
52
+
53
+ # --- Image Selection Section ---
54
+ st.subheader("Upload a Dental Panoramic X-ray Image or Select an Example")
55
+ image_file = st.file_uploader("Upload Image", type=["png", "jpg", "jpeg"])
56
+
57
+ st.write("---")
58
+ st.write("Or choose an example:")
59
+ examples = ["107.png", "108.png", "109.png"]
60
+ col1, col2, col3 = st.columns(3)
61
+
62
+ # Display example images and buttons to use them
63
+ with col1:
64
+ st.image(examples[0], caption='Example 1', use_column_width=True)
65
+ if st.button('Use Example 1'):
66
+ image_file = examples[0]
67
+
68
+ with col2:
69
+ st.image(examples[1], caption='Example 2', use_column_width=True)
70
+ if st.button('Use Example 2'):
71
+ image_file = examples[1]
72
+
73
+ with col3:
74
+ st.image(examples[2], caption='Example 3', use_column_width=True)
75
+ if st.button('Use Example 3'):
76
+ image_file = examples[2]
77
+
78
+ # --- Processing and Prediction Section ---
79
+ if image_file is not None:
80
+ st.write("---")
81
+
82
+ # Load and display the selected image
83
+ original_pil_img = load_image(image_file)
84
+ st.image(original_pil_img, caption="Original Image", use_column_width=True)
85
+
86
+ with st.spinner("Analyzing image and predicting segmentation..."):
87
+ # Convert PIL image to NumPy array for processing
88
+ original_np_img = np.array(original_pil_img)
89
+
90
+ # 1. Pre-process for the model
91
+ img_gray = convert_one_channel(original_np_img.copy())
92
+ img_resized = cv2.resize(img_gray, (512, 512), interpolation=cv2.INTER_LANCZOS4)
93
+ img_normalized = np.float32(img_resized / 255.0)
94
+ img_input = np.reshape(img_normalized, (1, 512, 512, 1))
95
+
96
+ # 2. Make prediction
97
+ prediction = model.predict(img_input)
98
+
99
+ # 3. Post-process the prediction mask
100
+ predicted_mask = prediction[0]
101
+ resized_mask = cv2.resize(predicted_mask, (original_np_img.shape[1], original_np_img.shape[0]), interpolation=cv2.INTER_LANCZOS4)
102
+
103
+ # Binarize the mask using Otsu's thresholding
104
+ mask_8bit = (resized_mask * 255).astype(np.uint8)
105
+ _, final_mask = cv2.threshold(mask_8bit, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
106
+
107
+ # Clean up mask with morphological operations
108
+ kernel = np.ones((5, 5), dtype=np.uint8)
109
+ final_mask = cv2.morphologyEx(final_mask, cv2.MORPH_OPEN, kernel, iterations=1)
110
+ final_mask = cv2.morphologyEx(final_mask, cv2.MORPH_CLOSE, kernel, iterations=1)
111
+
112
+ # Find contours on the final mask
113
+ contours, _ = cv2.findContours(final_mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
114
+
115
+ # Draw contours on a color version of the original image
116
+ img_for_drawing = convert_rgb(original_np_img.copy())
117
+ output_image = cv2.drawContours(img_for_drawing, contours, -1, (255, 0, 0), 3) # Draw red contours
118
+
119
+ st.subheader("Predicted Segmentation")
120
+ st.image(output_image, caption="Image with Segmented Teeth", use_column_width=True)
121
+
122
+ st.success("Prediction complete!")