import os
import pickle
import random

import gradio as gr
import numpy as np
from PIL import Image

# 1. Load the model
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only ever load a model.pkl you trust.
with open('model.pkl', 'rb') as f:
    model_params = pickle.load(f)

# Parameters of the two-layer network used by forward_prop.
W1 = model_params['W1']
b1 = model_params['b1']
W2 = model_params['W2']
b2 = model_params['b2']


# 2. Define helper functions
def ReLu(Z):
    """Apply the rectified linear unit element-wise: max(Z, 0)."""
    return np.maximum(Z, 0)


def softmax(Z):
    """Return the column-wise softmax of Z.

    Each column of Z holds one sample's logits. The per-column maximum is
    subtracted before exponentiating so large logits cannot overflow to inf
    (which would make the result NaN); the shift cancels out in the ratio,
    so the returned probabilities are mathematically unchanged.
    """
    Z = np.asarray(Z)
    exp_shifted = np.exp(Z - np.max(Z, axis=0, keepdims=True))
    return exp_shifted / np.sum(exp_shifted, axis=0, keepdims=True)


def forward_prop(W1, b1, W2, b2, X):
    """Run the forward pass of the two-layer network.

    Args:
        W1, b1: first-layer weights and bias.
        W2, b2: second-layer weights and bias.
        X: input with one sample per column (preprocess_image yields
            shape (784, 1)).

    Returns:
        Tuple (Z1, Z2, A1, A2): the pre-activations and activations of both
        layers; A2 is the softmax output, one column per sample.
    """
    Z1 = W1.dot(X) + b1
    A1 = ReLu(Z1)
    Z2 = W2.dot(A1) + b2
    A2 = softmax(Z2)
    return Z1, Z2, A1, A2


def get_predictions(A2):
    """Return the index of the highest-scoring row for each column of A2."""
    return np.argmax(A2, 0)


def preprocess_image(image):
    """Convert a PIL image into the network's input column vector.

    The image is converted to 8-bit grayscale, resized to 28x28, flattened
    and scaled from [0, 255] to [0, 1].

    NOTE(review): no colour inversion is done, so the uploaded image is
    assumed to have the same foreground/background polarity as the training
    data — confirm against how the model was trained.

    Returns:
        ndarray of shape (784, 1).
    """
    img = image.convert('L')
    img = img.resize((28, 28))
    img_array = np.array(img).reshape(1, 28 * 28) / 255.0
    return img_array.T  # (1, 784) -> (784, 1): one sample per column


# 3. Define prediction function
def predict_digit(image):
    """Predict the digit shown in ``image``.

    Returns:
        The predicted class index as an int, or None when no image has been
        provided (e.g. Predict clicked on an empty input). Returning None
        clears the output label instead of raising an AttributeError.
    """
    if image is None:
        return None
    X = preprocess_image(image)
    _, _, _, A2 = forward_prop(W1, b1, W2, b2, X)
    prediction = get_predictions(A2)
    return int(prediction[0])


# 4. Load sample images
sample_dir = "sample_images"  # Make sure this directory exists in your Space
_SAMPLE_EXTENSIONS = (".png", ".jpg", ".jpeg")
if os.path.isdir(sample_dir):
    # Sorted so the gallery order is stable across runs (os.listdir order is
    # arbitrary); extension match is case-insensitive so ".PNG" is included.
    sample_images = [
        os.path.join(sample_dir, filename)
        for filename in sorted(os.listdir(sample_dir))
        if filename.lower().endswith(_SAMPLE_EXTENSIONS)
    ]
else:
    # A missing sample directory only disables the sample features; the
    # rest of the app still works.
    sample_images = []


# 5. Define function to select random image
def select_random_image():
    """Return the path of a random sample image, or None if there are none."""
    if not sample_images:
        return None
    return random.choice(sample_images)
# 6. Create Gradio interface
def _uploaded_file_path(file):
    """Return the filesystem path of a file delivered by gr.UploadButton.

    Depending on the Gradio version and the button's ``type``, the upload
    payload is either a plain path string or a tempfile-like object exposing
    ``.name``. Accept both so the handler does not fail on a str.
    """
    if file is None:
        return None
    if isinstance(file, str):
        return file
    return file.name


with gr.Blocks() as demo:
    gr.Markdown("# Handwritten Digit Recognition")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Input Image")
            upload_button = gr.UploadButton("Upload Image", file_types=["image"])
            sample_button = gr.Button("Use Random Sample Image")
        with gr.Column():
            output_label = gr.Label(label="Prediction")
            predict_button = gr.Button("Predict")

    # Uploading or picking a sample fills the input image; Predict runs the
    # model on whatever image is currently shown.
    upload_button.upload(fn=_uploaded_file_path, inputs=upload_button, outputs=input_image)
    sample_button.click(fn=select_random_image, inputs=None, outputs=input_image)
    predict_button.click(fn=predict_digit, inputs=input_image, outputs=output_label)

    gr.Markdown("## Sample Images")
    # Display up to the first 10 sample images as two rows of 5.
    for row_start in (0, 5):
        with gr.Row():
            for img_path in sample_images[row_start:row_start + 5]:
                gr.Image(img_path, show_label=False, height=100)

# 7. Launch the app
demo.launch()