Custom datasets

Load your own data into PyTorch with Dataset, DataLoader, ImageFolder, samplers and a transforms pipeline.

Goal of the lesson

By the end of this 3-hour session you should be able to:

  • explain the contract of torch.utils.data.Dataset,
  • use ImageFolder for the common image-classification layout,
  • write your own Dataset for any data you have on disk,
  • compose a transforms pipeline with deterministic preprocessing and stochastic augmentation,
  • rebalance batches drawn from a class-imbalanced dataset with WeightedRandomSampler,
  • build a small image classifier from a folder of images you assembled yourself.

Suggested timing

Block     Topic
15 min    Why custom datasets, the Dataset contract
25 min    Get the data, ImageFolder
30 min    Transforms: deterministic vs. augmentation
30 min    Write a Dataset from scratch
25 min    DataLoader knobs and WeightedRandomSampler
55 min    Capstone: your own image classifier

Why custom datasets

Built-in datasets like FashionMNIST and MNIST are training wheels. The moment you have a real project you’ll be loading your own files: photos in folders, audio in WAV files, sensor logs in CSV, MRI scans in DICOM, etc.

PyTorch has a small, composable API for that:

Building block                      Purpose
torch.utils.data.Dataset            Your data, indexed by integer.
torch.utils.data.DataLoader         Wraps a Dataset to deliver batches, shuffling, parallel loading.
torchvision.datasets.ImageFolder    A ready-made Dataset for images organized by folder.
torchvision.transforms              Image preprocessing and augmentation.
torch.utils.data.Sampler            Decides which indices to draw on each epoch.

We’ll use a small subset of Food-101 (pizza, steak, sushi) as a running example. The same code patterns work for medical images, satellite images, audio spectrograms, or anything else you load from disk.
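
To make the Dataset contract concrete before touching any images, here is a minimal sketch of a map-style dataset. SquaresDataset is a hypothetical toy class invented for illustration, not part of the lesson's code: it only shows that a Dataset needs __len__ and __getitem__, and that DataLoader handles the batching.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class SquaresDataset(Dataset):
    """Hypothetical toy dataset: item i is the pair (i, i**2)."""

    def __init__(self, n: int) -> None:
        self.n = n

    def __len__(self) -> int:
        # Number of items; valid indices are 0 .. n - 1.
        return self.n

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
        # Return one (input, target) pair for an integer index.
        x = torch.tensor(float(idx))
        return x, x ** 2


dataset = SquaresDataset(10)

# DataLoader asks the dataset for individual items and stacks them into batches.
loader = DataLoader(dataset, batch_size=4, shuffle=False)
batch_shapes = [tuple(xb.shape) for xb, _ in loader]

print(len(dataset))   # 10
print(batch_shapes)   # [(4,), (4,), (2,)]
```

The last batch holds only two items because 10 is not a multiple of 4 and drop_last defaults to False.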

Setup

PowerShell
uv init --python 3.12 datasets
cd datasets
uv add torch torchvision matplotlib pillow requests
main.py
from pathlib import Path
import zipfile
import requests

import matplotlib.pyplot as plt
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler
from torchvision import datasets, transforms
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"
torch.manual_seed(42)

Get the data

The dataset ships as a zip file in the mrdbourke/pytorch-deep-learning GitHub repo. Download it once and unpack it into data/pizza_steak_sushi/.

main.py
DATA_PATH = Path("data")
IMAGE_PATH = DATA_PATH / "pizza_steak_sushi"

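if not IMAGE_PATH.is_dir():
    # Create only data/ here: if IMAGE_PATH existed before a failed download,
    # the check above would skip the download on every later run.
    DATA_PATH.mkdir(parents=True, exist_ok=True)

    url = "https://github.com/mrdbourke/pytorch-deep-learning/raw/main/data/pizza_steak_sushi.zip"
    zip_file = DATA_PATH / "pizza_steak_sushi.zip"

    # Fail loudly on HTTP errors instead of saving an error page as the zip.
    response = requests.get(url, timeout=60)
    response.raise_for_status()
    zip_file.write_bytes(response.content)

    # extractall() creates IMAGE_PATH while unpacking.
    with zipfile.ZipFile(zip_file, "r") as z:
        z.extractall(IMAGE_PATH)
    zip_file.unlink()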

train_dir = IMAGE_PATH / "train"
test_dir = IMAGE_PATH / "test"
print(list(train_dir.iterdir()))

The folder structure follows the standard convention for image classification: one folder per split, one subfolder per class, with the subfolder name serving as the label.

pizza_steak_sushi/
├── train/
│   ├── pizza/
│   ├── steak/
│   └── sushi/
└── test/
    ├── pizza/
    ├── steak/
    └── sushi/
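
A layout like this can be turned into labels with nothing but the standard library. The sketch below builds a throwaway copy of the tree in a temporary directory, so it runs without the download, and maps sorted subfolder names to integer labels, which is the same ordering ImageFolder uses for its class_to_idx attribute. The helper names class_index_from_folders and count_files_per_class are hypothetical, chosen only for this illustration.

```python
import tempfile
from pathlib import Path


def class_index_from_folders(split_dir: Path) -> dict[str, int]:
    """Map each class subfolder name to an integer label, in sorted order."""
    names = sorted(p.name for p in split_dir.iterdir() if p.is_dir())
    return {name: i for i, name in enumerate(names)}


def count_files_per_class(split_dir: Path) -> dict[str, int]:
    """Count the files directly inside each class subfolder."""
    return {
        name: sum(1 for f in (split_dir / name).iterdir() if f.is_file())
        for name in class_index_from_folders(split_dir)
    }


with tempfile.TemporaryDirectory() as tmp:
    # Stand-in for pizza_steak_sushi/train with empty placeholder files.
    train = Path(tmp) / "train"
    for name, n_files in [("sushi", 1), ("pizza", 3), ("steak", 2)]:
        (train / name).mkdir(parents=True)
        for i in range(n_files):
            (train / name / f"{i}.jpg").touch()

    class_to_idx = class_index_from_folders(train)
    counts = count_files_per_class(train)

print(class_to_idx)  # {'pizza': 0, 'steak': 1, 'sushi': 2}
print(counts)        # {'pizza': 3, 'steak': 2, 'sushi': 1}
```

Running count_files_per_class on the real train_dir is a quick way to spot class imbalance before training.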
