Kuzushiji 49 MNIST FCN Model

Overview

This repository contains a Fully Convolutional Neural Network (FCN) model for the Kuzushiji-49 dataset from the Kuzushiji-MNIST family. Kuzushiji-49 is an image classification dataset of cursive (kuzushiji) Japanese hiragana characters with 49 classes: 48 hiragana plus the iteration mark ゝ.

Model Architecture

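The model is a Fully Convolutional Neural Network (FCN): it contains no fully connected layers. Four convolutional blocks (3x3 Conv2d with stride 2, BatchNorm, GELU) downsample the 28x28 input to 2x2 while widening the channels from 3 to 256. After dropout (p = 0.5), a final 3x3 stride-2 convolution produces a 1x1 map with one channel per class, which is flattened into the class logits. The model was trained on the Kuzushiji-49 dataset with the configuration listed under Training Details.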
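The spatial sizes in the layer comments follow the standard convolution output-size formula, floor((H + 2p - k) / s) + 1, and the parameter count follows directly from the layer widths. The short pure-Python check below is a sketch that reproduces both for a 28x28 input and 49 classes. The helper names are illustrative, and only trainable parameters are counted (BatchNorm running statistics are buffers, not parameters).

```python
def conv_out(size: int, kernel: int = 3, stride: int = 2, padding: int = 1) -> int:
    """Output height/width of a square convolution: floor((H + 2p - k) / s) + 1."""
    return (size + 2 * padding - kernel) // stride + 1

# Spatial size after each of the five stride-2 convolutions
sizes = [28]
for _ in range(5):
    sizes.append(conv_out(sizes[-1]))
print(sizes)  # [28, 14, 7, 4, 2, 1]

# (in_channels, out_channels, followed_by_batchnorm) for each convolution
LAYERS = [(3, 16, True), (16, 64, True), (64, 128, True), (128, 256, True), (256, 49, False)]

def layer_params(c_in: int, c_out: int, has_bn: bool, k: int = 3) -> int:
    conv = c_in * c_out * k * k + c_out  # conv weights + bias
    bn = 2 * c_out if has_bn else 0      # BatchNorm scale + shift
    return conv + bn

total = sum(layer_params(*layer) for layer in LAYERS)
print(total)  # 492625
```

Most of the parameters sit in the last ConvBlock (128 to 256 channels) and in the final classification convolution.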

Usage

To use the model in your own project, you must first define the model classes and then load the pretrained weights into them. Here's a quick example:

```python
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.utils as vutils
from huggingface_hub import PyTorchModelHubMixin
from PIL import Image

# CNN block: Conv2d -> BatchNorm2d -> GELU
class ConvBlock(nn.Module):
    def __init__(self, c0, c1, k, s, p) -> None:
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(c0, c1, k, s, p),
            nn.BatchNorm2d(c1),
            nn.GELU(),
        )

    def forward(self, x):
        return self.net(x)

# FCN model. PyTorchModelHubMixin provides from_pretrained(); this assumes the
# weights on the Hub were saved with that mixin.
class FullyConv(nn.Module, PyTorchModelHubMixin):
    def __init__(self, num_classes) -> None:
        super().__init__()
        # Input 28x28; each stride-2 convolution roughly halves the spatial size
        self.net = nn.Sequential(
            ConvBlock(3, 16, 3, 2, 1),    # 14x14
            ConvBlock(16, 64, 3, 2, 1),   # 7x7
            ConvBlock(64, 128, 3, 2, 1),  # 4x4
            ConvBlock(128, 256, 3, 2, 1), # 2x2
            nn.Dropout(p=0.5, inplace=True),
            nn.Conv2d(256, num_classes, 3, 2, 1),  # 1x1, one channel per class
            nn.Flatten(),                 # (N, num_classes) logits
        )

        # Initialize model weights.
        self._initialize_weights()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)

    # Xavier (Glorot) normal initialization for conv and linear weights;
    # BatchNorm scale is set to 1 and shift to 0. Conv biases keep PyTorch's default init.
    def _initialize_weights(self):
        """Initializes the weights of all layers in this model."""
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.xavier_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

# Preprocessing: resize to 28x28, replicate the grayscale channel 3 times,
# convert to a [0, 1] tensor, then normalize with ImageNet mean/std
IMG_SIZE = 28
transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.Grayscale(3),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# The 49 Kuzushiji-49 classes, in class-index order
class_names = ['あ', 'い', 'う', 'え', 'お',   # A
               'か', 'き', 'く', 'け', 'こ',   # Ka
               'さ', 'し', 'す', 'せ', 'そ',   # Sa
               'た', 'ち', 'つ', 'て', 'と',   # Ta
               'な', 'に', 'ぬ', 'ね', 'の',   # Na
               'は', 'ひ', 'ふ', 'へ', 'ほ',   # Ha
               'ま', 'み', 'む', 'め', 'も',   # Ma
               'や', 'ゆ', 'よ',               # Ya
               'ら', 'り', 'る', 'れ', 'ろ',   # Ra
               'わ', 'ゐ', 'ゑ',               # Wa, Wi, We
               'を', 'ん', 'ゝ']               # Wo, N, iteration mark

# Create the model and load the pretrained weights from the Hugging Face Hub
model = FullyConv.from_pretrained("Hendrico/kmnist49-classifier", num_classes=len(class_names))
model.eval()  # disable dropout and use BatchNorm running statistics

# Predict a single image (file path or PIL image) and show it with the top-1 label
@torch.no_grad()
def predict(model, img, transform, class_names):
    if isinstance(img, str):
        img = Image.open(img).convert('RGB')
    # img = PIL.ImageOps.invert(img)  # optional: invert if your image's polarity differs from the training data
    inputs = transform(img).unsqueeze(0)  # (1, 3, 28, 28)
    out = model(inputs)                   # (1, num_classes) logits
    act_out = F.softmax(out, dim=1)       # class probabilities
    prob, pred = act_out.max(dim=1)       # top-1 probability and class index
    # Note: hiragana in the title may need a Japanese-capable matplotlib font
    plt.title(f"{class_names[pred.item()]} ({prob.item()*100:.2f}%)")
    grid = vutils.make_grid(inputs, padding=2, normalize=True)
    plt.imshow(grid.permute(1, 2, 0).cpu().numpy())
    plt.show()

# Call predict with an image file path or a PIL image
image_file = "path/to/image.png"  # hypothetical path; replace with your own image
predict(model, image_file, transform, class_names)
```
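The numeric steps inside `transform` and `predict` can be reproduced in plain NumPy, which may help when debugging inputs. This is a sketch under stated assumptions: the image is already a 28x28 grayscale array (so `Resize` and the RGB-to-gray conversion are skipped), and the helper names `preprocess` and `softmax` are hypothetical, not part of this repository.

```python
import numpy as np

IMAGENET_MEAN = np.array([0.485, 0.456, 0.406]).reshape(3, 1, 1)
IMAGENET_STD = np.array([0.229, 0.224, 0.225]).reshape(3, 1, 1)

def preprocess(gray: np.ndarray) -> np.ndarray:
    """(28, 28) uint8 grayscale image -> (1, 3, 28, 28) normalized batch."""
    x = gray.astype(np.float64) / 255.0       # ToTensor: [0, 255] -> [0, 1]
    x = np.stack([x, x, x], axis=0)           # Grayscale(3): copy to 3 channels
    x = (x - IMAGENET_MEAN) / IMAGENET_STD    # Normalize, per channel
    return x[np.newaxis]                      # unsqueeze(0): add batch dimension

def softmax(logits: np.ndarray) -> np.ndarray:
    """Numerically stable softmax over the class dimension (axis 1)."""
    z = logits - logits.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

batch = preprocess(np.full((28, 28), 255, dtype=np.uint8))
print(batch.shape)  # (1, 3, 28, 28)

probs = softmax(np.array([[1.0, 2.0, 3.0]]))
pred = probs.argmax(axis=1)                   # like act_out.max(dim=1)
print(pred[0], round(float(probs[0, pred[0]]), 4))  # 2 0.6652
```

Because each channel is normalized with a different mean and standard deviation, the three "gray" channels the model sees are not identical after `Normalize`.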

Training Details

The model was trained using the following configuration:

- Dataset: Kuzushiji 49 MNIST
- Image size: 28x28
- Augmentation: sharpness and rotation
- Optimizer: RAdam
- Loss function: CrossEntropyLoss
- Epochs: 200
- Batch size: 512
- Learning rate: 0.0001
- Weight decay: 0.05
- Decoupled weight decay: True
- Scheduler: ReduceLROnPlateau

Evaluation Results

The model achieved the following performance metrics on the test set:

- Accuracy: 0.855546357615894
- Precision (macro): 0.8588960428290617
- Recall (macro): 0.8551147959183673
- F1 score (macro): 0.8550432473709819
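Macro-averaged scores weight every class equally, regardless of how many test images it has. The NumPy sketch below shows how such metrics are computed from a confusion matrix; the helper `macro_metrics` is hypothetical and not part of this repository. It averages per-class F1 scores, the same convention as scikit-learn's `f1_score(average="macro")`, so macro F1 is generally not the harmonic mean of macro precision and macro recall. Unlike scikit-learn's default, it averages over all `num_classes` classes, even ones absent from both labels and predictions.

```python
import numpy as np

def macro_metrics(y_true, y_pred, num_classes):
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    # Confusion matrix: rows = true class, columns = predicted class
    cm = np.zeros((num_classes, num_classes), dtype=np.int64)
    np.add.at(cm, (y_true, y_pred), 1)
    tp = np.diag(cm).astype(float)
    pred_totals = cm.sum(axis=0)  # TP + FP per class
    true_totals = cm.sum(axis=1)  # TP + FN per class
    # Per-class scores; classes with a zero denominator score 0
    precision = np.divide(tp, pred_totals, out=np.zeros_like(tp), where=pred_totals > 0)
    recall = np.divide(tp, true_totals, out=np.zeros_like(tp), where=true_totals > 0)
    denom = precision + recall
    f1 = np.divide(2 * precision * recall, denom, out=np.zeros_like(tp), where=denom > 0)
    return {
        "accuracy": tp.sum() / len(y_true),
        "precision_macro": precision.mean(),
        "recall_macro": recall.mean(),
        "f1_macro": f1.mean(),
    }

print(macro_metrics([0, 0, 1, 1, 2], [0, 1, 1, 1, 2], num_classes=3))
```

Passing arrays of class indices (0 to 48) for the whole test set would yield the four numbers in the same form as those reported above.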

Acknowledgments

Dataset: https://www.kaggle.com/datasets/anokas/kuzushiji

License

Copyright 2024 IoriU

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
