resnet-cifar10-5m

Model description

This model is a wide ResNet with 4,327,754 parameters, trained from scratch on the CIFAR-10 dataset. The architecture is a widened ResNet20: three stages of three basic blocks each, with a base width of 64 instead of 16 (stage widths 64, 128 and 256), which gives about 4.3M parameters. Training was performed on a consumer AMD Radeon RX 6600 GPU using PyTorch with ROCm support.
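As a sanity check, the parameter count can be derived by hand from the layer shapes. The sketch below uses hypothetical helper names that are not part of this repository. It mirrors the ResNetWide definition in How to use: bias-free 3x3 convolutions, a BatchNorm weight and bias per channel, a 1x1 projection shortcut wherever the stride or width changes, and a final linear layer with bias.

```python
def conv3x3_params(c_in, c_out):
    # 3x3 kernel, bias=False (as in conv3x3 of the model code)
    return c_in * c_out * 9


def bn_params(channels):
    # BatchNorm2d affine weight + bias; running stats are buffers, not parameters
    return 2 * channels


def basic_block_params(c_in, c_out, stride):
    n = conv3x3_params(c_in, c_out) + bn_params(c_out)
    n += conv3x3_params(c_out, c_out) + bn_params(c_out)
    if stride != 1 or c_in != c_out:
        # 1x1 projection shortcut (bias=False) followed by BatchNorm
        n += c_in * c_out + bn_params(c_out)
    return n


def wide_resnet20_params(base_width=64, num_blocks=(3, 3, 3), num_classes=10):
    total = conv3x3_params(3, base_width) + bn_params(base_width)   # stem
    c_in = base_width
    for stage, blocks in enumerate(num_blocks):
        c_out = base_width * 2 ** stage          # 64, 128, 256 for base_width=64
        for i in range(blocks):
            stride = 2 if (stage > 0 and i == 0) else 1
            total += basic_block_params(c_in, c_out, stride)
            c_in = c_out
    total += c_in * num_classes + num_classes    # nn.Linear with bias
    return total


print(wide_resnet20_params(64))   # widened model described in this card
print(wide_resnet20_params(16))   # same formula at ResNet20's original width
```

With base_width=64 the formula gives 4,327,754, which matches the reported count. At ResNet20's original width of 16 it gives 272,474. Convolution weights grow with the square of the width, so quadrupling the width multiplies the count by roughly 16.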

Author: sapbot from Romarchive

Intended uses

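  • Image classification of small (32x32) images into the 10 CIFAR-10 categories: airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck.
  • Educational purposes: shows that a medium-sized model trained from scratch, with no large-scale pre-training, can reach about 90% test accuracy.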

Training procedure

  • Optimizer: SGD with momentum 0.9 and weight decay 1e‑4.
  • Learning rate: 0.1, multiplied by 0.1 at epochs 80 and 120 (MultiStepLR), giving 0.1, 0.01 and 0.001 in the three phases.
  • Batch size: 128.
  • Epochs: up to 160 (early stopping with patience 10).
  • Data augmentation: random horizontal flip, random crop with 4‑pixel padding.
  • Input normalization: mean (0.4914, 0.4822, 0.4465), std (0.2023, 0.1994, 0.2010).
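The learning-rate schedule, stopping rule and normalization above can be sketched in plain Python to make the epoch boundaries explicit. In PyTorch the schedule corresponds to torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[80, 120], gamma=0.1), stepped once per epoch. The helpers below (lr_at, EarlyStopping, normalize) are hypothetical illustrations, not the author's training code. The card does not say which quantity early stopping monitors, so the sketch assumes a validation loss where lower is better.

```python
import bisect
import math

BASE_LR = 0.1
MILESTONES = [80, 120]   # epochs at which the LR is multiplied by GAMMA
GAMMA = 0.1

CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2023, 0.1994, 0.2010)


def lr_at(epoch):
    """LR for a 0-based epoch index under a MultiStepLR-style schedule."""
    drops = bisect.bisect_right(MILESTONES, epoch)   # milestones already reached
    return BASE_LR * GAMMA ** drops


class EarlyStopping:
    """Stop after `patience` consecutive epochs without improvement.

    Assumption: the monitored metric is a validation loss (lower is better).
    """

    def __init__(self, patience=10):
        self.patience = patience
        self.best = math.inf
        self.bad_epochs = 0

    def step(self, metric):
        """Record one epoch's metric; return True when training should stop."""
        if metric < self.best:
            self.best = metric
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


def normalize(rgb):
    """Per-channel (x - mean) / std for one RGB pixel with values in [0, 1]."""
    return tuple((x - m) / s for x, m, s in zip(rgb, CIFAR10_MEAN, CIFAR10_STD))


for epoch in (0, 79, 80, 119, 120, 159):
    print(epoch, f"{lr_at(epoch):g}")
print(normalize((1.0, 1.0, 1.0)))   # a pure white pixel
```

Under this schedule, epochs 0 to 79 use 0.1, epochs 80 to 119 use 0.01, and epochs 120 to 159 use 0.001.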

Evaluation results

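Metric                                        Value
Test accuracy (10,000 test images)            89.84%
Test loss                                     0.3625
Accuracy on all 60,000 images (train + test)  94.96%

The 60k figure includes the 50,000 training images the model was fit on, so it overstates generalization. The test-set accuracy is the number to compare against other models.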
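CIFAR-10 has 50,000 training and 10,000 test images, so the two accuracy figures can be combined to back out the accuracy on the training split alone. This sketch assumes both figures come from the same checkpoint and are only rounded to two decimals. implied_train_accuracy is a hypothetical helper.

```python
TRAIN_SIZE, TEST_SIZE = 50_000, 10_000


def implied_train_accuracy(test_acc, full_acc):
    """Training-split accuracy implied by the test-set and all-60k accuracies."""
    correct_all = full_acc * (TRAIN_SIZE + TEST_SIZE)   # correct over 60k images
    correct_test = test_acc * TEST_SIZE                 # correct over the 10k test images
    return (correct_all - correct_test) / TRAIN_SIZE


train_acc = implied_train_accuracy(test_acc=0.8984, full_acc=0.9496)
print(f"implied training accuracy: {train_acc:.2%}")   # prints 95.98%
```

The result is about 96.0% on the training split versus 89.84% on the test split, a generalization gap of roughly 6 percentage points.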

Per‑class performance

                precision    recall  f1-score   support

  0 (airplane)     0.8996    0.9320    0.9155      1000
1 (automobile)     0.9174    0.9770    0.9462      1000
      2 (bird)     0.8867    0.8370    0.8611      1000
       3 (cat)     0.7227    0.8730    0.7908      1000
      4 (deer)     0.9366    0.8870    0.9111      1000
       5 (dog)     0.8822    0.8090    0.8440      1000
      6 (frog)     0.9430    0.8940    0.9179      1000
     7 (horse)     0.9145    0.9410    0.9276      1000
      8 (ship)     0.9571    0.9360    0.9464      1000
     9 (truck)     0.9677    0.8980    0.9315      1000

      accuracy                         0.8984     10000
     macro avg     0.9027    0.8984    0.8992     10000
  weighted avg     0.9027    0.8984    0.8992     10000

Confusion matrix (row = true label, col = predicted label)

True \ Pred       0    1    2    3    4    5    6    7    8    9
0 airplane      932    7   18   11    4    0    2    3   20    3
1 automobile      1  977    0    0    0    1    1    3    3   14
2 bird           27    0  837   59   19   22   18   16    2    0
3 cat            11    2   17  873   13   53   14   10    3    4
4 deer            5    1   21   39  887    9    6   30    2    0
5 dog             3    1   11  138   11  809    7   18    2    0
6 frog            5    2   29   55    5    5  894    5    0    0
7 horse           4    1    8   18    8   18    0  941    0    2
8 ship           37    8    1    7    0    0    3    1  936    7
9 truck          11   66    2    8    0    0    3    2   10  898
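The per-class precision, recall and F1 in the report above follow directly from this matrix. Precision divides the diagonal entry by its column sum, recall divides it by its row sum, and F1 is their harmonic mean. The sketch below recomputes them in plain Python and also lists the largest off-diagonal cells. The helpers are hypothetical and not part of the repository; scikit-learn's classification_report produces the same report layout.

```python
CLASSES = ["airplane", "automobile", "bird", "cat", "deer",
           "dog", "frog", "horse", "ship", "truck"]

# Rows = true label, columns = predicted label (values copied from the matrix above).
CONFUSION = [
    [932,   7,  18,  11,   4,   0,   2,   3,  20,   3],
    [  1, 977,   0,   0,   0,   1,   1,   3,   3,  14],
    [ 27,   0, 837,  59,  19,  22,  18,  16,   2,   0],
    [ 11,   2,  17, 873,  13,  53,  14,  10,   3,   4],
    [  5,   1,  21,  39, 887,   9,   6,  30,   2,   0],
    [  3,   1,  11, 138,  11, 809,   7,  18,   2,   0],
    [  5,   2,  29,  55,   5,   5, 894,   5,   0,   0],
    [  4,   1,   8,  18,   8,  18,   0, 941,   0,   2],
    [ 37,   8,   1,   7,   0,   0,   3,   1, 936,   7],
    [ 11,  66,   2,   8,   0,   0,   3,   2,  10, 898],
]


def per_class_metrics(cm):
    """Return (precision, recall, f1) for each class of a square confusion matrix."""
    n = len(cm)
    results = []
    for k in range(n):
        tp = cm[k][k]
        predicted = sum(cm[i][k] for i in range(n))   # column sum
        actual = sum(cm[k])                           # row sum
        precision = tp / predicted if predicted else 0.0
        recall = tp / actual if actual else 0.0
        denom = precision + recall
        f1 = 2 * precision * recall / denom if denom else 0.0
        results.append((precision, recall, f1))
    return results


def accuracy(cm):
    return sum(cm[k][k] for k in range(len(cm))) / sum(map(sum, cm))


def top_confusions(cm, k=3):
    """Largest off-diagonal cells as (count, true class, predicted class)."""
    cells = [(cm[i][j], CLASSES[i], CLASSES[j])
             for i in range(len(cm)) for j in range(len(cm)) if i != j]
    return sorted(cells, reverse=True)[:k]


metrics = per_class_metrics(CONFUSION)
for name, (p, r, f1) in zip(CLASSES, metrics):
    print(f"{name:>10}  precision={p:.4f}  recall={r:.4f}  f1={f1:.4f}")
print(f"accuracy={accuracy(CONFUSION):.4f}")
print(top_confusions(CONFUSION))
```

For this model the largest cell is dog images predicted as cat (138). Next are truck predicted as automobile (66) and bird predicted as cat (59). These errors are consistent with cat having the lowest precision in the report.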

Model size

4,327,754 trainable parameters.
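To estimate the memory the weights need, multiply the parameter count by the bytes per value. This is a back-of-the-envelope sketch. The saved state_dict also holds BatchNorm buffers (running_mean, running_var, num_batches_tracked) plus serialization overhead, so the actual best_resnet5m_cifar10.pth file will be slightly larger. The float16 row assumes the weights are cast to half precision, which this card does not do.

```python
NUM_PARAMS = 4_327_754


def weight_bytes(num_params, bytes_per_value):
    """Raw storage for the trainable parameters only (no buffers, no file overhead)."""
    return num_params * bytes_per_value


for dtype, nbytes in (("float32", 4), ("float16", 2)):
    size = weight_bytes(NUM_PARAMS, nbytes)
    print(f"{dtype}: {size:,} bytes = {size / 2**20:.2f} MiB")
# float32: 17,311,016 bytes = 16.51 MiB
# float16: 8,655,508 bytes = 8.25 MiB
```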

How to use

import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image

# --------------------------------------------------
# Wide ResNet for CIFAR-10 (exact architecture)
# --------------------------------------------------
def conv3x3(in_planes, out_planes, stride=1):
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)

class BasicBlock(nn.Module):
    expansion = 1
    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(in_planes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes)
            )

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = self.relu(out)
        return out

class ResNetWide(nn.Module):
    def __init__(self, block, num_blocks, base_width=64, num_classes=10):
        super(ResNetWide, self).__init__()
        self.in_planes = base_width
        self.conv1 = conv3x3(3, base_width)
        self.bn1 = nn.BatchNorm2d(base_width)
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self._make_layer(block, base_width, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, base_width * 2, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, base_width * 4, num_blocks[2], stride=2)
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(base_width * 4 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.avg_pool(out)
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        return out

def ResNet5M():
    return ResNetWide(BasicBlock, [3, 3, 3], base_width=64)

# --------------------------------------------------
# Load the model weights
# --------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ResNet5M().to(device)
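# weights_only=True restricts unpickling to tensors and plain containers (enough for a state_dict)
model.load_state_dict(torch.load("best_resnet5m_cifar10.pth", map_location=device, weights_only=True))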
model.eval()

# --------------------------------------------------
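# Preprocess an image: the input may be any size; predict() converts it to RGB,
# then it is resized to 32x32 (non-square images are stretched) and normalized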
# --------------------------------------------------
transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

def predict(image_path):
    image = Image.open(image_path).convert('RGB')
    input_tensor = transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        output = model(input_tensor)
        _, predicted = output.max(1)
    classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
    return classes[predicted.item()]

# Example usage:
# print(predict("my_cat.jpg"))
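The predict function above returns only the arg-max label. A softmax turns the logits into probabilities, which shows how confident the model is. In PyTorch this is usually torch.softmax(output, dim=1) followed by .topk(k). The dependency-free sketch below does the same computation on a plain list, for example output[0].tolist() taken inside predict. The helper names and the example logits are made up for illustration and were not produced by this model.

```python
import math

CLASSES = ["airplane", "automobile", "bird", "cat", "deer",
           "dog", "frog", "horse", "ship", "truck"]


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    shift = max(logits)                        # subtracting the max avoids overflow in exp
    exps = [math.exp(z - shift) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k(logits, k=3):
    """Return the k most likely (class name, probability) pairs."""
    probs = softmax(logits)
    ranked = sorted(zip(probs, CLASSES), reverse=True)[:k]
    return [(name, p) for p, name in ranked]


# Illustrative logits only (hypothetical, not real model output).
example_logits = [0.1, -1.2, 0.3, 4.0, 0.2, 2.5, 0.0, 0.4, -0.5, -0.8]
for name, p in top_k(example_logits):
    print(f"{name}: {p:.3f}")
```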

Acknowledgements

  • Original CIFAR‑10 dataset by Alex Krizhevsky, Vinod Nair, and Geoffrey Hinton.
  • PyTorch team for the framework and ROCm support.