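# NOTE: The lines "Spaces:", "Sleeping", "Sleeping" are Hugging Face Spaces
# page-status text that was captured together with this file; they are kept
# here as a comment so the module parses.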
# Importing necessary libraries | |
import numpy as np # For handling arrays and numerical operations | |
import gradio as gr # For creating a simple web interface for interacting with the model | |
from PIL import Image # For image processing (like resizing) | |
from tensorflow import keras # For building and working with machine learning models | |
# Build the classifier: a small fully connected network that takes a 28x28
# image and produces one score per digit 0-9.
model = keras.models.Sequential([
    keras.layers.Flatten(input_shape=(28, 28)),    # 28x28 image -> 784-value vector
    keras.layers.Dense(512, activation='relu'),    # hidden layer 1
    keras.layers.Dense(512, activation='relu'),    # hidden layer 2
    keras.layers.Dense(10, activation='softmax'),  # softmax: outputs are already probabilities
])

# Compile with Adam and sparse categorical cross-entropy (integer class labels).
# The output layer applies softmax, so the loss must be told it receives
# probabilities (from_logits=False). from_logits=True would make the loss apply
# softmax a second time, producing a wrong loss if this model is ever trained
# or evaluated.
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=0.001),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=False),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

# Load the pre-trained weights. The path is relative to the current working
# directory, so the app must be started from the folder that contains `weights/`.
model.load_weights('./weights/mnist.weights.h5')
# Defining the function that will be used to classify the input image | |
def classify(input): | |
# Preprocessing the input image | |
# Convert the input image (which is in the format of a list of pixel values) to a numpy array | |
# Resize it to 28x28 pixels (if not already 28x28) | |
image = np.expand_dims(np.array(Image.fromarray(input['layers'][0]) # 'input' contains the image data | |
.resize((28, 28), resample=Image.Resampling.BILINEAR), dtype=int), axis=0) # Resizing to match model input size | |
# Predicting the digit from the processed image using the model | |
prediction = model.predict(image).tolist()[0] # Getting the output prediction as a list | |
# Returning the probabilities for each of the 10 digits (0 to 9) | |
return {str(i): float(prediction[i]) for i in range(10)} # Converting the predictions into a dictionary of probabilities | |
# Setting up the Gradio interface for user interaction | |
# The user will draw a digit on the sketchpad, which will be classified | |
input_sketchpad = gr.Paint(image_mode="L", brush=gr.components.image_editor.Brush(default_color="rgb(156, 104, 200)")) # The input is a paint canvas where the user can draw | |
output_label = gr.Label() # A label to display the predicted output (probabilities of each digit) | |
# Creating and launching the Gradio interface | |
gr.Interface(fn=classify, # The function that will handle the classification | |
inputs=input_sketchpad, # The input is the paint canvas where the user draws the digit | |
outputs=output_label, # The output will display the predicted label | |
flagging_mode='never', # Disable flagging for the interface | |
theme=gr.themes.Soft()).launch() # Using a soft theme for the interface and launching it | |