papasega commited on
Commit
36306c8
·
verified ·
1 Parent(s): 2b04118

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +96 -44
app.py CHANGED
@@ -2,56 +2,108 @@ import tensorflow as tf
2
  import gradio as gr
3
  import numpy as np
4
  import cv2
 
 
5
 
6
- # 1. Load Model (Optimized load)
7
- model = tf.keras.models.load_model('TP_MNIST_CNN_model.h5')
8
-
9
- def predict_digit(input_data):
10
- """
11
- Pipeline de prédiction robuste :
12
- 1. Gestion du format d'entrée (Gradio 4 renvoie parfois un dict).
13
- 2. Resize vers 28x28 (Interpolation AREA pour préserver les traits).
14
- 3. Conversion Grayscale.
15
- 4. Inversion des couleurs (Adaptation domaine Humain -> Machine).
16
- 5. Normalisation et inférence.
17
- """
 
 
 
 
 
 
 
 
18
  if input_data is None:
19
  return None
20
 
21
- # Gradio 4 handle : input_data peut être un dictionnaire {'composite': ...}
22
- image = input_data["composite"] if isinstance(input_data, dict) else input_data
 
 
 
 
 
23
 
24
- # Pipeline OpenCV
25
- # Resize vers 28x28
26
- image = cv2.resize(image, (28, 28), interpolation=cv2.INTER_AREA)
 
 
 
 
 
 
 
 
 
 
 
27
 
28
- # Convertir en niveaux de gris si nécessaire
29
- if len(image.shape) == 3:
30
- image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
- # Inversion intelligente : Si l'image est majoritairement blanche (dessin noir sur fond blanc)
33
- # on inverse car MNIST a été entraîné sur blanc sur fond noir.
34
- if np.mean(image) > 127:
35
- image = 255 - image
36
 
37
- # Normalisation
38
- image = image / 255.0
39
 
40
- # Reshape (Batch, H, W, Channels)
41
- image = image.reshape(1, 28, 28, 1)
42
-
43
- # Inférence
44
- prediction = model.predict(image, verbose=0)
45
- return int(np.argmax(prediction))
46
-
47
- # Interface Gradio 4 Moderne
48
- iface = gr.Interface(
49
- fn=predict_digit,
50
- inputs=gr.Sketchpad(label="Dessinez un chiffre", type="numpy", crop_size=(28, 28)),
51
- outputs="label",
52
- title="Reconnaissance MNIST - Production Grade",
53
- description="CNN Model. Dessinez un chiffre au centre.",
54
- allow_flagging="never"
55
- )
56
-
57
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
2
  import gradio as gr
3
  import numpy as np
4
  import cv2
5
+ from scipy.ndimage import center_of_mass
6
+ import math
7
 
8
+ # 1. Chargement du modèle
9
+ model = tf.keras.models.load_model('mnist_cnn_v1.keras')
10
+
11
+ def get_best_shift(img):
12
+ """Calcule le décalage optimal pour centrer l'image par centre de masse."""
13
+ cy, cx = center_of_mass(img)
14
+ rows, cols = img.shape
15
+ shiftx = np.round(cols/2.0-cx).astype(int)
16
+ shifty = np.round(rows/2.0-cy).astype(int)
17
+ return shiftx, shifty
18
+
19
+ def shift(img, sx, sy):
20
+ """Applique le décalage géométrique."""
21
+ rows, cols = img.shape
22
+ M = np.float32([[1, 0, sx], [0, 1, sy]])
23
+ shifted = cv2.warpAffine(img, M, (cols, rows))
24
+ return shifted
25
+
26
+ def preprocess_image(input_data):
27
+ """Pipeline robuste : Resize -> Gray -> Invert -> Center -> Normalize"""
28
  if input_data is None:
29
  return None
30
 
31
+ # Gestion format Gradio 4 (Dict ou Array)
32
+ img = input_data["composite"] if isinstance(input_data, dict) else input_data
33
+
34
+ # 1. Resize initial + Grayscale
35
+ img = cv2.resize(img, (28, 28), interpolation=cv2.INTER_AREA)
36
+ if len(img.shape) == 3:
37
+ img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
38
 
39
+ # 2. Inversion (Noir sur Blanc -> Blanc sur Noir)
40
+ # MNIST est blanc sur noir. Si l'utilisateur dessine en noir, on inverse.
41
+ if np.mean(img) > 127:
42
+ img = 255 - img
43
+
44
+ # 3. Nettoyage du bruit (Thresholding)
45
+ (_, img) = cv2.threshold(img, 128, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)
46
+
47
+ # 4. Centrage par Centre de Masse (CRITIQUE pour la précision)
48
+ # On ajoute une marge noire temporaire pour éviter de couper le chiffre lors du shift
49
+ while np.sum(img[0]) == 0: img = img[1:]
50
+ while np.sum(img[:,0]) == 0: img = img[:,1:]
51
+ while np.sum(img[-1]) == 0: img = img[:-1]
52
+ while np.sum(img[:,-1]) == 0: img = img[:,:-1]
53
 
54
+ rows, cols = img.shape
55
+ if rows > cols:
56
+ factor = 20.0/rows
57
+ rows = 20
58
+ cols = int(round(cols*factor))
59
+ img = cv2.resize(img, (cols, rows))
60
+ else:
61
+ factor = 20.0/cols
62
+ cols = 20
63
+ rows = int(round(rows*factor))
64
+ img = cv2.resize(img, (cols, rows))
65
+
66
+ colsPadding = (int(math.ceil((28-cols)/2.0)), int(math.floor((28-cols)/2.0)))
67
+ rowsPadding = (int(math.ceil((28-rows)/2.0)), int(math.floor((28-rows)/2.0)))
68
+ img = np.lib.pad(img, (rowsPadding, colsPadding), 'constant')
69
+
70
+ shiftx, shifty = get_best_shift(img)
71
+ shifted = shift(img, shiftx, shifty)
72
+ img = shifted
73
 
74
+ # 5. Normalisation et Reshape final
75
+ img = img / 255.0
76
+ img = img.reshape(1, 28, 28, 1)
77
+ return img
78
 
79
+ def predict(image):
80
+ if image is None: return None
81
 
82
+ processed_img = preprocess_image(image)
83
+ prediction = model.predict(processed_img, verbose=0)[0]
84
+
85
+ # Retourne un dictionnaire {Label: Confiance} pour Gradio
86
+ return {str(i): float(prediction[i]) for i in range(10)}
87
+
88
+ # --- UI Moderne avec Gradio Blocks ---
89
+ css = """
90
+ .container {max-width: 800px; margin: auto; padding-top: 20px}
91
+ #title {text-align: center; font-size: 2em; font-weight: bold; margin-bottom: 20px}
92
+ """
93
+
94
+ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
95
+ gr.Markdown("# 🧠 MNIST Classifier: From Sketch to Prediction", elem_id="title")
96
+ gr.Markdown("Dessinez un chiffre (0-9). L'IA le centrera automatiquement avant l'analyse.")
97
+
98
+ with gr.Row():
99
+ with gr.Column():
100
+ input_sketch = gr.Sketchpad(label="Dessinez ici", type="numpy", crop_size=(200, 200))
101
+ predict_btn = gr.Button("Analyser", variant="primary")
102
+
103
+ with gr.Column():
104
+ label_output = gr.Label(num_top_classes=3, label="Prédictions & Probabilités")
105
+
106
+ predict_btn.click(fn=predict, inputs=input_sketch, outputs=label_output)
107
+
108
+ if __name__ == "__main__":
109
+ demo.launch()