
Vision

Train convolutional neural networks on image data with PyTorch and torchvision, from a FashionMNIST warm-up to an MNIST capstone.

Goal of the lesson

By the end of this 3-hour session you should be able to:

  • read images as [C, H, W] tensors and feed them to a model in batches,
  • explain what a convolution does and why CNNs beat plain MLPs on images,
  • compute the spatial output size of a Conv2d / MaxPool2d layer,
  • assemble a small CNN with Conv2d, MaxPool2d, BatchNorm2d, and Dropout,
  • train it on FashionMNIST with batched loaders,
  • visualize predictions and a confusion matrix,
  • build an MNIST digit recognizer as a capstone and test it on hand-drawn digits.

Suggested timing

Block    Topic
20 min   Image tensors, transforms, DataLoader
25 min   Convolution intuition: kernels, strides, padding
25 min   Build a CNN, trace shapes through it
35 min   Train and evaluate on FashionMNIST
25 min   Visualize predictions, confusion matrix
50 min   Capstone: MNIST digit recognizer

Setup

ps
PowerShell
uv init --python 3.12 vision
cd vision
uv add torch torchvision matplotlib
python
main.py
import matplotlib.pyplot as plt
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

device = "cuda" if torch.cuda.is_available() else "cpu"
torch.manual_seed(42)

Image tensors

PyTorch images are [C, H, W] tensors:

  • C = channels (1 for grayscale, 3 for RGB),
  • H = height in pixels,
  • W = width in pixels.
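
As a small illustration of this layout, the sketch below builds synthetic images with torch.rand (random values, not a real dataset) and stacks them into a batch. The variable names are chosen for this example only. Stacking adds a leading batch dimension N, giving the [N, C, H, W] shape that PyTorch convolution layers take as input.

```python
import torch

# Synthetic images filled with random values, for illustration only.
gray = torch.rand(1, 28, 28)   # C=1 (grayscale), H=28, W=28
rgb = torch.rand(3, 32, 48)    # C=3 (RGB), H=32, W=48

channels, height, width = rgb.shape

# Stacking same-sized images adds a leading batch dimension N.
batch = torch.stack([gray, gray, gray, gray])   # [N=4, C=1, H=28, W=28]

# matplotlib's imshow expects color images as [H, W, C],
# so a channel-first tensor is permuted before plotting.
rgb_hwc = rgb.permute(1, 2, 0)

print(gray.shape, rgb.shape, batch.shape, rgb_hwc.shape)
```

For a single-channel image the channel axis can simply be dropped with squeeze() before plotting, which leaves a plain [H, W] array.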

torchvision.datasets ships ready-to-use datasets. With download=True, a dataset is downloaded into root on first use and read from disk on later runs.

python
main.py
train_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)

test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=transforms.ToTensor(),
)

class_names = train_data.classes
print(len(train_data), len(test_data))  # 60000 training and 10000 test images
print(class_names)                      # the 10 clothing categories

transforms.ToTensor() converts a PIL image into a float32 [C, H, W] tensor and rescales its 8-bit pixel values from [0, 255] to [0, 1].

python
main.py
image, label = train_data[0]
print("image shape:", image.shape)     # torch.Size([1, 28, 28])
print("label      :", label, class_names[label])

plt.imshow(image.squeeze(), cmap="gray")
plt.title(class_names[label])
plt.axis("off")
plt.show()
