Dataset: `uoft-cs/cifar10` on the Hugging Face Hub (60,000 images, 10 classes).
This model is a widened ResNet-20 (4,327,754 parameters) trained from scratch on the CIFAR-10 dataset. It keeps the ResNet-20 layout (three stages of three basic blocks) but uses a base width of 64 channels instead of 16, which gives ~4.3M parameters. Training was performed on a consumer AMD Radeon RX 6600 GPU using PyTorch with ROCm support.
Author: sapbot from Romarchive
| Metric | Value |
|---|---|
| Test accuracy (10k images) | 89.84% |
| Test loss | 0.3625 |
| Full CIFAR-10 accuracy (60k images = 50k train + 10k test; includes training data, so not a generalization measure) | 94.96% |
Per-class classification report on the 10k test set:

```text
                precision    recall  f1-score   support

  0 (airplane)     0.8996    0.9320    0.9155      1000
1 (automobile)     0.9174    0.9770    0.9462      1000
      2 (bird)     0.8867    0.8370    0.8611      1000
       3 (cat)     0.7227    0.8730    0.7908      1000
      4 (deer)     0.9366    0.8870    0.9111      1000
       5 (dog)     0.8822    0.8090    0.8440      1000
      6 (frog)     0.9430    0.8940    0.9179      1000
     7 (horse)     0.9145    0.9410    0.9276      1000
      8 (ship)     0.9571    0.9360    0.9464      1000
     9 (truck)     0.9677    0.8980    0.9315      1000

      accuracy                         0.8984     10000
     macro avg     0.9027    0.8984    0.8992     10000
  weighted avg     0.9027    0.8984    0.8992     10000
```
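Confusion matrix on the 10k test set (rows = true class, columns = predicted class; indices 0–9 are airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck):

| True \ Pred | Pred 0 | Pred 1 | Pred 2 | Pred 3 | Pred 4 | Pred 5 | Pred 6 | Pred 7 | Pred 8 | Pred 9 |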
|---|---|---|---|---|---|---|---|---|---|---|
| True 0 | 932 | 7 | 18 | 11 | 4 | 0 | 2 | 3 | 20 | 3 |
| True 1 | 1 | 977 | 0 | 0 | 0 | 1 | 1 | 3 | 3 | 14 |
| True 2 | 27 | 0 | 837 | 59 | 19 | 22 | 18 | 16 | 2 | 0 |
| True 3 | 11 | 2 | 17 | 873 | 13 | 53 | 14 | 10 | 3 | 4 |
| True 4 | 5 | 1 | 21 | 39 | 887 | 9 | 6 | 30 | 2 | 0 |
| True 5 | 3 | 1 | 11 | 138 | 11 | 809 | 7 | 18 | 2 | 0 |
| True 6 | 5 | 2 | 29 | 55 | 5 | 5 | 894 | 5 | 0 | 0 |
| True 7 | 4 | 1 | 8 | 18 | 8 | 18 | 0 | 941 | 0 | 2 |
| True 8 | 37 | 8 | 1 | 7 | 0 | 0 | 3 | 1 | 936 | 7 |
| True 9 | 11 | 66 | 2 | 8 | 0 | 0 | 3 | 2 | 10 | 898 |
4,327,754 trainable parameters.
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
# --------------------------------------------------
# Wide ResNet for CIFAR-10 (exact architecture)
# --------------------------------------------------
def conv3x3(in_planes, out_planes, stride=1):
    """Build a bias-free 3x3 convolution with 1-pixel zero padding.

    With ``stride=1`` the spatial size is preserved; with ``stride=2`` it is
    halved (rounded up). Every use in this file is immediately followed by a
    BatchNorm layer, which makes a convolution bias redundant.
    """
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
class BasicBlock(nn.Module):
    """Two-convolution residual block: conv-BN-ReLU, conv-BN, add shortcut, ReLU.

    The submodule names (``conv1``, ``bn1``, ``conv2``, ``bn2``, ``shortcut``)
    form the checkpoint's ``state_dict`` keys, so they must not be renamed.

    Args:
        in_planes: number of channels in the incoming feature map.
        planes: number of channels produced by both 3x3 convolutions.
        stride: stride of the first convolution (and of the projection
            shortcut, when one is needed).
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        self.conv1 = conv3x3(in_planes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)

        # The shortcut is the identity unless the block changes the spatial
        # size or channel count; then a 1x1 projection (+BN) reshapes the input.
        needs_projection = stride != 1 or in_planes != out_planes
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )
        else:
            self.shortcut = nn.Sequential()  # an empty Sequential returns its input

    def forward(self, x):
        identity = self.shortcut(x)
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))
        y += identity
        return self.relu(y)
class ResNetWide(nn.Module):
    """CIFAR-style three-stage ResNet with a configurable base width.

    The stem is a 3x3 conv + BN + ReLU on the 3-channel input. The three stages
    have ``base_width``, ``2 * base_width`` and ``4 * base_width`` channels;
    stages 2 and 3 start with a stride-2 block. Global average pooling feeds a
    linear classifier. Attribute names match the checkpoint's ``state_dict``
    keys and must not be renamed.

    Args:
        block: residual block class exposing ``expansion`` and accepting
            ``(in_planes, planes, stride)`` (e.g. ``BasicBlock``).
        num_blocks: sequence whose first three entries give the block count
            of each stage.
        base_width: channel count of the stem and the first stage.
        num_classes: length of the output logit vector.
    """

    def __init__(self, block, num_blocks, base_width=64, num_classes=10):
        super().__init__()
        # Running channel count; _make_layer advances it as stages are built.
        self.in_planes = base_width

        self.conv1 = conv3x3(3, base_width)
        self.bn1 = nn.BatchNorm2d(base_width)
        self.relu = nn.ReLU(inplace=True)

        self.layer1 = self._make_layer(block, base_width, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 2 * base_width, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 4 * base_width, num_blocks[2], stride=2)

        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(4 * base_width * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack blocks into one stage; only the first block uses ``stride``."""
        block_strides = [stride, *([1] * (num_blocks - 1))]
        stage = []
        for block_stride in block_strides:
            stage.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*stage)

    def forward(self, x):
        features = self.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3):
            features = stage(features)
        # (N, C, 1, 1) -> (N, C) for the linear classifier.
        pooled = torch.flatten(self.avg_pool(features), 1)
        return self.fc(pooled)
def ResNet5M():
    """Return the network architecture that matches the released checkpoint.

    It uses three stages of three ``BasicBlock``s with a base width of 64.
    Despite the name, this configuration has 4,327,754 parameters (~4.3M).
    """
    blocks_per_stage = [3, 3, 3]
    return ResNetWide(BasicBlock, blocks_per_stage, base_width=64)
# --------------------------------------------------
# Load the model weights
# --------------------------------------------------
# Checkpoint file produced by training; load_state_dict below expects a state dict.
CHECKPOINT_PATH = "best_resnet5m_cifar10.pth"

# On ROCm builds of PyTorch, AMD GPUs are also exposed through torch.cuda.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ResNet5M().to(device)
# weights_only=True restricts unpickling to tensors and plain containers, so a
# downloaded or tampered checkpoint cannot execute arbitrary code when loaded.
state_dict = torch.load(CHECKPOINT_PATH, map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.eval()  # inference: BatchNorm uses its running statistics
# --------------------------------------------------
# Preprocess an image of any size: resized to 32x32 and normalized (predict() converts to RGB first)
# --------------------------------------------------
# Per-channel normalisation constants, presumably the ones used at training
# time. NOTE(review): this std triple is the widely copied CIFAR-10 constant,
# not the exact per-pixel std (~0.247, 0.243, 0.261). Do not "fix" it without
# retraining, or predictions will shift.
_CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
_CIFAR10_STD = (0.2023, 0.1994, 0.2010)

transform = transforms.Compose(
    [
        transforms.Resize((32, 32)),  # any input size -> exactly 32x32 (aspect ratio not kept)
        transforms.ToTensor(),  # PIL HWC [0, 255] -> float CHW [0, 1]
        transforms.Normalize(_CIFAR10_MEAN, _CIFAR10_STD),
    ]
)
# Class names in label-index order (CIFAR-10 label ids 0-9).
CIFAR10_CLASSES = (
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
)


def predict(image_path):
    """Classify one image with the module-level CIFAR-10 ``model``.

    Args:
        image_path: filename or file object accepted by ``PIL.Image.open``.
            Any size or colour mode is accepted; the image is converted to
            RGB and then resized by ``transform``.

    Returns:
        The predicted class name, one of ``CIFAR10_CLASSES``.
    """
    # A ``with`` block closes the file handle deterministically. ``convert``
    # reads the pixel data into a new in-memory image before the file closes.
    with Image.open(image_path) as img:
        image = img.convert('RGB')
    batch = transform(image).unsqueeze(0).to(device)  # add batch dim: (1, 3, 32, 32)
    with torch.no_grad():
        logits = model(batch)
    class_index = logits.argmax(dim=1).item()
    return CIFAR10_CLASSES[class_index]
# Example usage:
# print(predict("my_cat.jpg"))