Model Card: MRI Brain Tumor Classification (ResNet-18)

Model Details

  • Model Name: MRIResnet
  • Architecture: ResNet-18-based model for MRI brain tumor classification
  • Dataset: Brain Tumor MRI Dataset
  • Batch Size: 32
  • Loss Function: Cross-Entropy Loss (nn.CrossEntropyLoss)
  • Optimizer: Adam (learning rate = 1e-3)
  • Transfer Learning: Yes (ImageNet-pretrained ResNet-18; the first convolution and the classifier head are replaced)

Model Architecture

This model is based on ResNet-18, a widely used convolutional neural network, and has been adapted for MRI-based brain tumor classification.

Modifications

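  • Input Channel Adaptation: The first convolutional layer (conv1) is replaced with a single-input-channel convolution (64 filters, 7×7 kernel, stride 2, padding 3, no bias) so the network accepts grayscale MRI scans. Because this layer is newly created, it starts from random initialization rather than from the pretrained RGB filters.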
  • Classifier Head: The final fully connected (FC) layer is replaced with a 512-to-4 linear layer, one output per class (the card assumes four tumor categories).
  • Transfer Learning:
    • Frozen Layers: All pretrained layers are frozen (requires_grad = False); only the two replaced layers are updated during training.
    • Trainable Layers:
      • First convolutional layer (conv1)
      • Fully connected classification layer (fc)
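
To make the effect of these modifications concrete, here is a back-of-the-envelope parameter count. It is a sketch that assumes torchvision's standard ResNet-18 layout (bias-free convolutions, BatchNorm contributing a learnable weight and bias per channel, two BasicBlocks per stage, a 1×1 downsample convolution when the channel count changes); the helper names are hypothetical. It shows that only about 5.2K of the roughly 11.2M parameters are trainable.

```python
def conv_params(in_ch: int, out_ch: int, k: int) -> int:
    """Weights of a bias-free k x k convolution."""
    return out_ch * in_ch * k * k


def bn_params(ch: int) -> int:
    """Learnable BatchNorm parameters (weight + bias); running stats are buffers."""
    return 2 * ch


def basic_block(in_ch: int, out_ch: int, downsample: bool) -> int:
    """Two 3x3 conv+BN pairs, plus a 1x1 conv+BN shortcut when shapes change."""
    total = conv_params(in_ch, out_ch, 3) + bn_params(out_ch)
    total += conv_params(out_ch, out_ch, 3) + bn_params(out_ch)
    if downsample:
        total += conv_params(in_ch, out_ch, 1) + bn_params(out_ch)
    return total


def resnet18_params(in_channels: int, num_classes: int) -> int:
    """Total learnable parameters of a ResNet-18 with the given stem and head."""
    total = conv_params(in_channels, 64, 7) + bn_params(64)  # conv1 + bn1
    in_ch = 64
    for out_ch in (64, 128, 256, 512):  # layer1 .. layer4
        total += basic_block(in_ch, out_ch, downsample=in_ch != out_ch)
        total += basic_block(out_ch, out_ch, downsample=False)
        in_ch = out_ch
    total += 512 * num_classes + num_classes  # fc weight + bias
    return total


original = resnet18_params(in_channels=3, num_classes=1000)  # stock ImageNet model
modified = resnet18_params(in_channels=1, num_classes=4)     # grayscale, 4 classes
# Only the replaced conv1 and fc layers are trainable.
trainable = conv_params(1, 64, 7) + (512 * 4 + 4)
frozen = modified - trainable

print(f"original:  {original:,}")
print(f"modified:  {modified:,}")
print(f"trainable: {trainable:,} ({100 * trainable / modified:.3f}% of the model)")
print(f"frozen:    {frozen:,}")
```

The modified total (11,172,292) is consistent with the 11.2M parameter figure reported for the stored weights, and the trainable fraction is under 0.05%, which is why only conv1 and fc need optimizer state.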

Implementation

Model Definition

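import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin
from torchvision.models import ResNet18_Weights, resnet18

class MRIResnet(nn.Module, PyTorchModelHubMixin):
    def __init__(self):
        super().__init__()
        # ImageNet-1K pretrained backbone (same weights as the legacy pretrained=True)
        self.base_model = resnet18(weights=ResNet18_Weights.DEFAULT)
        # New single-channel stem with the original conv1 geometry;
        # its weights are randomly initialized, not pretrained
        self.base_model.conv1 = nn.Conv2d(
            1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
        )
        # New classification head: 512 pooled features -> 4 class logits
        self.base_model.fc = nn.Linear(self.base_model.fc.in_features, 4)

        # Freeze all layers except the modified ones
        for param in self.base_model.parameters():
            param.requires_grad = False

        for param in self.base_model.conv1.parameters():
            param.requires_grad = True
        for param in self.base_model.fc.parameters():
            param.requires_grad = True

    def forward(self, x):
        # x: (batch, 1, H, W) grayscale MRI tensor -> (batch, 4) logits
        return self.base_model(x)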

This model has been pushed to the Hugging Face Hub using the PyTorchModelHubMixin integration.

Stored weights: Safetensors format, 11.2M parameters, tensor type F32.