Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -2,56 +2,108 @@ import tensorflow as tf
|
|
| 2 |
import gradio as gr
|
| 3 |
import numpy as np
|
| 4 |
import cv2
|
|
|
|
|
|
|
| 5 |
|
| 6 |
-
# 1.
|
| 7 |
-
model = tf.keras.models.load_model('
|
| 8 |
-
|
| 9 |
-
def
|
| 10 |
-
"""
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
if input_data is None:
|
| 19 |
return None
|
| 20 |
|
| 21 |
-
# Gradio 4
|
| 22 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
|
| 24 |
-
#
|
| 25 |
-
#
|
| 26 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
|
| 28 |
-
|
| 29 |
-
if
|
| 30 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
|
| 32 |
-
#
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
|
| 37 |
-
|
| 38 |
-
image
|
| 39 |
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
#
|
| 44 |
-
prediction
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
import gradio as gr
|
| 3 |
import numpy as np
|
| 4 |
import cv2
|
| 5 |
+
from scipy.ndimage import center_of_mass
|
| 6 |
+
import math
|
| 7 |
|
| 8 |
+
# 1. Chargement du modèle
|
| 9 |
+
model = tf.keras.models.load_model('mnist_cnn_v1.keras')
|
| 10 |
+
|
| 11 |
+
def get_best_shift(img):
|
| 12 |
+
"""Calcule le décalage optimal pour centrer l'image par centre de masse."""
|
| 13 |
+
cy, cx = center_of_mass(img)
|
| 14 |
+
rows, cols = img.shape
|
| 15 |
+
shiftx = np.round(cols/2.0-cx).astype(int)
|
| 16 |
+
shifty = np.round(rows/2.0-cy).astype(int)
|
| 17 |
+
return shiftx, shifty
|
| 18 |
+
|
| 19 |
+
def shift(img, sx, sy):
|
| 20 |
+
"""Applique le décalage géométrique."""
|
| 21 |
+
rows, cols = img.shape
|
| 22 |
+
M = np.float32([[1, 0, sx], [0, 1, sy]])
|
| 23 |
+
shifted = cv2.warpAffine(img, M, (cols, rows))
|
| 24 |
+
return shifted
|
| 25 |
+
|
| 26 |
+
def preprocess_image(input_data):
|
| 27 |
+
"""Pipeline robuste : Resize -> Gray -> Invert -> Center -> Normalize"""
|
| 28 |
if input_data is None:
|
| 29 |
return None
|
| 30 |
|
| 31 |
+
# Gestion format Gradio 4 (Dict ou Array)
|
| 32 |
+
img = input_data["composite"] if isinstance(input_data, dict) else input_data
|
| 33 |
+
|
| 34 |
+
# 1. Resize initial + Grayscale
|
| 35 |
+
img = cv2.resize(img, (28, 28), interpolation=cv2.INTER_AREA)
|
| 36 |
+
if len(img.shape) == 3:
|
| 37 |
+
img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
|
| 38 |
|
| 39 |
+
# 2. Inversion (Noir sur Blanc -> Blanc sur Noir)
|
| 40 |
+
# MNIST est blanc sur noir. Si l'utilisateur dessine en noir, on inverse.
|
| 41 |
+
if np.mean(img) > 127:
|
| 42 |
+
img = 255 - img
|
| 43 |
+
|
| 44 |
+
# 3. Nettoyage du bruit (Thresholding)
|
| 45 |
+
(_, img) = cv2.threshold(img, 128, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)
|
| 46 |
+
|
| 47 |
+
# 4. Centrage par Centre de Masse (CRITIQUE pour la précision)
|
| 48 |
+
# On ajoute une marge noire temporaire pour éviter de couper le chiffre lors du shift
|
| 49 |
+
while np.sum(img[0]) == 0: img = img[1:]
|
| 50 |
+
while np.sum(img[:,0]) == 0: img = img[:,1:]
|
| 51 |
+
while np.sum(img[-1]) == 0: img = img[:-1]
|
| 52 |
+
while np.sum(img[:,-1]) == 0: img = img[:,:-1]
|
| 53 |
|
| 54 |
+
rows, cols = img.shape
|
| 55 |
+
if rows > cols:
|
| 56 |
+
factor = 20.0/rows
|
| 57 |
+
rows = 20
|
| 58 |
+
cols = int(round(cols*factor))
|
| 59 |
+
img = cv2.resize(img, (cols, rows))
|
| 60 |
+
else:
|
| 61 |
+
factor = 20.0/cols
|
| 62 |
+
cols = 20
|
| 63 |
+
rows = int(round(rows*factor))
|
| 64 |
+
img = cv2.resize(img, (cols, rows))
|
| 65 |
+
|
| 66 |
+
colsPadding = (int(math.ceil((28-cols)/2.0)), int(math.floor((28-cols)/2.0)))
|
| 67 |
+
rowsPadding = (int(math.ceil((28-rows)/2.0)), int(math.floor((28-rows)/2.0)))
|
| 68 |
+
img = np.lib.pad(img, (rowsPadding, colsPadding), 'constant')
|
| 69 |
+
|
| 70 |
+
shiftx, shifty = get_best_shift(img)
|
| 71 |
+
shifted = shift(img, shiftx, shifty)
|
| 72 |
+
img = shifted
|
| 73 |
|
| 74 |
+
# 5. Normalisation et Reshape final
|
| 75 |
+
img = img / 255.0
|
| 76 |
+
img = img.reshape(1, 28, 28, 1)
|
| 77 |
+
return img
|
| 78 |
|
| 79 |
+
def predict(image):
|
| 80 |
+
if image is None: return None
|
| 81 |
|
| 82 |
+
processed_img = preprocess_image(image)
|
| 83 |
+
prediction = model.predict(processed_img, verbose=0)[0]
|
| 84 |
+
|
| 85 |
+
# Retourne un dictionnaire {Label: Confiance} pour Gradio
|
| 86 |
+
return {str(i): float(prediction[i]) for i in range(10)}
|
| 87 |
+
|
| 88 |
+
# --- UI Moderne avec Gradio Blocks ---
|
| 89 |
+
css = """
|
| 90 |
+
.container {max-width: 800px; margin: auto; padding-top: 20px}
|
| 91 |
+
#title {text-align: center; font-size: 2em; font-weight: bold; margin-bottom: 20px}
|
| 92 |
+
"""
|
| 93 |
+
|
| 94 |
+
with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
|
| 95 |
+
gr.Markdown("# 🧠 MNIST Classifier: From Sketch to Prediction", elem_id="title")
|
| 96 |
+
gr.Markdown("Dessinez un chiffre (0-9). L'IA le centrera automatiquement avant l'analyse.")
|
| 97 |
+
|
| 98 |
+
with gr.Row():
|
| 99 |
+
with gr.Column():
|
| 100 |
+
input_sketch = gr.Sketchpad(label="Dessinez ici", type="numpy", crop_size=(200, 200))
|
| 101 |
+
predict_btn = gr.Button("Analyser", variant="primary")
|
| 102 |
+
|
| 103 |
+
with gr.Column():
|
| 104 |
+
label_output = gr.Label(num_top_classes=3, label="Prédictions & Probabilités")
|
| 105 |
+
|
| 106 |
+
predict_btn.click(fn=predict, inputs=input_sketch, outputs=label_output)
|
| 107 |
+
|
| 108 |
+
if __name__ == "__main__":
|
| 109 |
+
demo.launch()
|