Custom datasets
Load your own data into PyTorch with Dataset, DataLoader, ImageFolder, samplers and a transforms pipeline.
Goal of the lesson
By the end of this 3-hour session you should be able to:
- explain the contract of torch.utils.data.Dataset,
- use ImageFolder for the common image-classification layout,
- write your own Dataset for any data you have on disk,
- compose a transforms pipeline with deterministic preprocessing and stochastic augmentation,
- rebalance batches drawn from an imbalanced dataset with WeightedRandomSampler,
- build a small image classifier from a folder of images you assembled yourself.
Suggested timing
| Block | Topic |
|---|---|
| 15 min | Why custom datasets, the Dataset contract |
| 25 min | Get the data, ImageFolder |
| 30 min | Transforms — deterministic vs. augmentation |
| 30 min | Write a Dataset from scratch |
| 25 min | DataLoader knobs and WeightedRandomSampler |
| 55 min | Capstone — your own image classifier |
Why custom datasets
Built-in datasets like FashionMNIST and MNIST are training wheels. The moment you have a real project you’ll be loading your own files: photos in folders, audio in WAV files, sensor logs in CSV, MRI scans in DICOM, etc.
PyTorch has a small, composable API for that:
| Building block | Purpose |
|---|---|
| torch.utils.data.Dataset | Your data, indexed by integer. |
| torch.utils.data.DataLoader | Wraps a Dataset and handles batching, shuffling, and parallel loading. |
| torchvision.datasets.ImageFolder | A ready-made Dataset for images organized in one folder per class. |
| torchvision.transforms | Image preprocessing and augmentation. |
| torch.utils.data.Sampler | Decides which indices to draw, and in what order, each epoch. |
We’ll use a small subset of Food-101 (pizza, steak, sushi) as a running example. The same code patterns work for medical images, satellite images, audio spectrograms, or anything else you load from disk.
Setup
```bash
uv init --python 3.12 datasets
cd datasets
uv add torch torchvision matplotlib pillow requests
```
Get the data
The dataset ships as a zip on the mrdbourke/pytorch-deep-learning repo. Download it once and unpack it into data/.
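```python
import zipfile
from pathlib import Path

import requests

data_path = Path("data")
image_path = data_path / "pizza_steak_sushi"
url = "https://github.com/mrdbourke/pytorch-deep-learning/raw/main/data/pizza_steak_sushi.zip"
if not image_path.is_dir():  # download and unpack only once
    image_path.mkdir(parents=True)
    zip_path = data_path / "pizza_steak_sushi.zip"
    response = requests.get(url, timeout=60)
    response.raise_for_status()
    zip_path.write_bytes(response.content)
    with zipfile.ZipFile(zip_path) as zf:
        zf.extractall(image_path)  # creates train/ and test/
train_dir = image_path / "train"
test_dir = image_path / "test"
```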
The folder structure is the standard convention for image classification:
```text
pizza_steak_sushi/
├── train/
│   ├── pizza/
│   ├── steak/
│   └── sushi/
└── test/
    ├── pizza/
    ├── steak/
    └── sushi/
```

The directory name is the label. This layout is so common that ImageFolder automates it.
Inspect the data
Before doing anything else, look at a few samples:
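```python
import random

import matplotlib.pyplot as plt
from PIL import Image

image_paths = sorted(image_path.glob("*/*/*.jpg"))  # <split>/<class>/<file>.jpg
fig, axes = plt.subplots(2, 4, figsize=(12, 6))
for ax in axes.flat:
    path = random.choice(image_paths)
    img = Image.open(path)
    ax.imshow(img); ax.set_title(f"{path.parent.name} {img.size}")
    ax.axis("off")
plt.show()
```

The images vary in size, lighting, and framing. Welcome to real data.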
Transforms
Models expect tensors of a fixed size and dtype. torchvision.transforms builds a pipeline that runs on each image as it is loaded.
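```python
from torchvision import transforms

train_transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.TrivialAugmentWide(),
    transforms.ToTensor(),
])

test_transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
])
```

Two rules of thumb: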
- The same deterministic preprocessing (resize, normalize) is applied to both train and test.
- Augmentations (random flips, crops, color jitter, TrivialAugmentWide) belong only in the training transforms. The test set must measure how well you do on the real data, not on randomly modified data.
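transforms.ToTensor() converts a PIL image to a [C, H, W] float tensor with values in [0, 1]. Place it after the transforms that work on the PIL image (resize, flips, TrivialAugmentWide) and before any tensor-only transform such as Normalize, which must come after it.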
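The arithmetic behind ToTensor, and behind Normalize when it follows ToTensor, fits in a few lines of plain Python. This sketch works on a tiny [H][W][C] nested list of 8-bit values; to_chw_unit and normalize are hypothetical names that only mirror the documented behaviour (divide by 255 and reorder to [C, H, W]; then (x - mean) / std per channel).

```python
def to_chw_unit(hwc):
    """Mimic ToTensor on an 8-bit RGB image: [H][W][C] ints in 0..255 -> [C][H][W] floats in [0, 1]."""
    height, width, channels = len(hwc), len(hwc[0]), len(hwc[0][0])
    return [[[hwc[h][w][c] / 255 for w in range(width)]
             for h in range(height)]
            for c in range(channels)]


def normalize(chw, mean, std):
    """Mimic Normalize: subtract each channel's mean and divide by its std."""
    return [[[(v - mean[c]) / std[c] for v in row] for row in plane]
            for c, plane in enumerate(chw)]


# A 1x2 image: one red pixel and one white pixel.
image = [[[255, 0, 0], [255, 255, 255]]]
chw = to_chw_unit(image)
print(chw)  # [[[1.0, 1.0]], [[0.0, 1.0]], [[0.0, 1.0]]]
print(normalize(chw, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]))
# [[[1.0, 1.0]], [[-1.0, 1.0]], [[-1.0, 1.0]]]
```

Because Normalize reads channel means from the [C, H, W] layout, it only makes sense once the channel axis comes first, which is one reason it follows ToTensor.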
Visualize what augmentation does
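```python
img = Image.open(random.choice(image_paths))
augmented = train_transform(img)  # tensor [3, 64, 64]; different on every call
fig, axes = plt.subplots(1, 2, figsize=(8, 4))
axes[0].imshow(img); axes[0].set_title("Original")
axes[1].imshow(augmented.permute(1, 2, 0)); axes[1].set_title("Augmented")  # CHW -> HWC
plt.show()
```

Call train_transform(img) a few times on the same img: each call produces a different augmented version, while the underlying photo stays the same.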
Try it — design augmentations for a new task
For each task, suggest transforms you would and wouldn’t include in training. Why?
- Recognising pets vs. wild animals from photos.
- Reading handwritten digits like MNIST.
- Detecting cracks in concrete from drone imagery.
Solution
- Pets: horizontal flip ✓ (a flipped dog is still a dog), color jitter ✓ (lighting varies), vertical flip ✗ (cats rarely walk on the ceiling).
- Digits: small rotation/translation ✓, horizontal flip ✗ (“3” flipped is not a 3), color jitter ✗ (the dataset is grayscale).
- Concrete cracks: rotations / flips in any direction ✓ (drone orientation is arbitrary), color jitter ✓ (sun changes over the day).
Option 1 — ImageFolder
When the data is already organized by folder, ImageFolder is the shortest path from disk to Dataset:
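```python
from torchvision import datasets

train_data = datasets.ImageFolder(root=train_dir, transform=train_transform)
test_data = datasets.ImageFolder(root=test_dir, transform=test_transform)

print(train_data.classes)       # ['pizza', 'steak', 'sushi']
print(train_data.class_to_idx)  # {'pizza': 0, 'steak': 1, 'sushi': 2}

img, label = train_data[0]
print(img.shape, label)         # torch.Size([3, 64, 64]) 0
```

ImageFolder walks the directory once at construction time, builds the list of (path, class_index) samples, and opens each image lazily in __getitem__.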
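The directory walk described above can be mimicked with pathlib. This is a simplified sketch, not torchvision's source: it assumes the root/<class>/<file> layout, sorts class names alphabetically to assign indices (consistent with the class_to_idx output above), and keeps only files with the listed extensions. index_image_folder is a hypothetical name.

```python
import tempfile
from pathlib import Path


def index_image_folder(root, extensions=(".jpg", ".jpeg", ".png")):
    """Return (classes, class_to_idx, samples) for a root/<class>/<file> layout.

    Only paths are collected here; no image is opened.
    """
    root = Path(root)
    classes = sorted(d.name for d in root.iterdir() if d.is_dir())  # alphabetical order
    class_to_idx = {name: i for i, name in enumerate(classes)}
    samples = [
        (path, class_to_idx[name])
        for name in classes
        for path in sorted((root / name).iterdir())
        if path.suffix.lower() in extensions
    ]
    return classes, class_to_idx, samples


# Usage on a throwaway directory that mimics train/ with empty placeholder files.
with tempfile.TemporaryDirectory() as tmp:
    for name in ["sushi", "pizza", "steak"]:
        (Path(tmp) / name).mkdir()
        (Path(tmp) / name / "0001.jpg").touch()
    (Path(tmp) / "pizza" / "notes.txt").touch()  # skipped: not an image extension
    classes, class_to_idx, samples = index_image_folder(tmp)

print(classes)       # ['pizza', 'steak', 'sushi']
print(class_to_idx)  # {'pizza': 0, 'steak': 1, 'sushi': 2}
print(len(samples))  # 3
```

Sorting is what makes the label indices stable across machines: the order in which the filesystem lists folders does not matter.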
Option 2 — Write your own Dataset
ImageFolder covers most needs but you’ll often have to write a Dataset yourself: a CSV of (path, label) pairs, multi-label data, paired images, audio, sensor logs, etc.
A map-style Dataset, the kind DataLoader indexes by integer, implements three methods:
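```python
from pathlib import Path

from PIL import Image
from torch.utils.data import Dataset


class ImageFolderCustom(Dataset):
    def __init__(self, targ_dir, transform=None):
        # Store paths only; images are opened lazily in __getitem__.
        self.paths = sorted(Path(targ_dir).glob("*/*.jpg"))
        self.transform = transform
        self.classes = sorted(d.name for d in Path(targ_dir).iterdir() if d.is_dir())
        self.class_to_idx = {name: i for i, name in enumerate(self.classes)}

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        path = self.paths[index]
        img = Image.open(path).convert("RGB")
        class_idx = self.class_to_idx[path.parent.name]  # folder name = label
        if self.transform:
            img = self.transform(img)
        return img, class_idx


train_data_custom = ImageFolderCustom(train_dir, transform=train_transform)
img, label = train_data_custom[0]
```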
The contract:
| Method | What it does |
|---|---|
| __len__ | Returns the number of samples. |
| __getitem__(i) | Returns the i-th (input, target) pair. |
| __init__ | Builds whatever index or lookup you need to make __getitem__ cheap. |
Two important properties:
- Index-by-int. Valid indices are dataset[0], dataset[1], …, dataset[len(dataset) - 1]. The DataLoader decides the order.
- Lazy. Don’t load all images in __init__; store paths and read on demand. Otherwise large datasets won’t fit in memory.
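Both properties can be checked without PyTorch at all. The sketch below is a plain Python class following the same __len__/__getitem__ contract; fake_load is a hypothetical stand-in for Image.open that records which items were actually read, so you can confirm that construction reads nothing and indexing reads one item at a time.

```python
class LazyRecordDataset:
    """Map-style dataset: stores cheap keys in __init__, loads one item per __getitem__."""

    def __init__(self, records, load_fn):
        self.records = list(records)  # (key, label) pairs; cheap to keep in memory
        self.load_fn = load_fn        # e.g. Image.open in a real image dataset

    def __len__(self):
        return len(self.records)

    def __getitem__(self, index):
        key, label = self.records[index]
        return self.load_fn(key), label


loaded = []


def fake_load(key):
    loaded.append(key)   # remember what was read
    return key.upper()   # stand-in for decoding a file


ds = LazyRecordDataset([("a.jpg", 0), ("b.jpg", 1), ("c.jpg", 2)], fake_load)
print(len(ds), loaded)  # 3 []
print(ds[2], loaded)    # ('C.JPG', 2) ['c.jpg']
```

Subclassing torch.utils.data.Dataset instead of object changes nothing about this contract; it is these two methods that a DataLoader with the default samplers relies on.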
A Dataset from a CSV
A common real-world layout is images/ plus a CSV with path,label rows:
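```text
path,label
images/0001.jpg,pizza
```

```python
import csv
from pathlib import Path

from PIL import Image
from torch.utils.data import Dataset


class CSVImageDataset(Dataset):  # class name is illustrative
    def __init__(self, csv_file, root, transform=None):
        self.root = Path(root)
        self.transform = transform
        with open(csv_file, newline="") as f:
            self.rows = [(row["path"], row["label"]) for row in csv.DictReader(f)]
        self.classes = sorted({label for _, label in self.rows})
        self.class_to_idx = {name: i for i, name in enumerate(self.classes)}

    def __len__(self):
        return len(self.rows)

    def __getitem__(self, index):
        path, label = self.rows[index]
        img = Image.open(self.root / path).convert("RGB")
        if self.transform:
            img = self.transform(img)
        return img, self.class_to_idx[label]
```

You’ll reuse this pattern over and over: just swap the storage layer (CSV, JSON, database, S3) and the loader (Image.open, torchaudio.load, pandas.read_parquet).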
Try it — multi-label dataset
Imagine each photo has multiple labels (["beach", "sunset"], ["cat", "indoors"]). Sketch the changes you’d make to __getitem__ and to the loss function.
Solution
__getitem__ returns a multi-hot float tensor of length num_classes (one 1.0 for each label) instead of a single int. Loss switches from CrossEntropyLoss to BCEWithLogitsLoss. The model output stays as num_classes logits, but predictions become “logit > 0” per class instead of argmax.
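The two changes in that solution, multi-hot targets and a per-class "logit > 0" decision, are easy to sketch in plain Python. multi_hot and predict_labels are hypothetical helpers and the class names and logits are made up; in a real Dataset you would wrap the returned list in torch.tensor(..., dtype=torch.float32) so BCEWithLogitsLoss receives float targets.

```python
def multi_hot(labels, class_to_idx):
    """Encode a list of label names as floats with 1.0 at each present class."""
    vec = [0.0] * len(class_to_idx)
    for name in labels:
        vec[class_to_idx[name]] = 1.0
    return vec


def predict_labels(logits, classes):
    """Multi-label rule: keep every class whose logit is > 0 (equivalently sigmoid > 0.5)."""
    return [name for name, z in zip(classes, logits) if z > 0]


classes = ["beach", "cat", "indoors", "sunset"]
class_to_idx = {name: i for i, name in enumerate(classes)}
print(multi_hot(["beach", "sunset"], class_to_idx))     # [1.0, 0.0, 0.0, 1.0]
print(predict_labels([2.3, -1.0, -0.2, 0.7], classes))  # ['beach', 'sunset']
```

Unlike argmax, this rule can return zero, one, or several labels per photo.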
DataLoader
DataLoader wraps any Dataset and yields batches.
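```python
from torch.utils.data import DataLoader

BATCH_SIZE = 32

train_loader = DataLoader(
    train_data,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,
)
test_loader = DataLoader(
    test_data,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=0,
)

img, label = next(iter(train_loader))
print(img.shape, label.shape)  # torch.Size([32, 3, 64, 64]) torch.Size([32])
```

The most useful parameters: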
| Parameter | Meaning |
|---|---|
| batch_size | How many samples per batch. |
| shuffle | Reshuffle the indices at the start of every epoch. True for train, False for test. |
| num_workers | Number of background worker processes. 0 loads in the main process; >0 loads batches in parallel. |
| pin_memory | Puts batches in page-locked memory for faster CPU→GPU copies (set True when training on a GPU). |
| drop_last | Drop the last batch if it’s smaller than batch_size. |
| sampler | Replaces the default sequential/random sampler; see below. |
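To make batch_size, shuffle, and drop_last concrete, here is a plain-Python sketch of the index bookkeeping a DataLoader does with its default samplers. It is an illustration under simplifying assumptions (single process, no collation into tensors), not PyTorch's implementation; epoch_batches and num_batches are hypothetical names, and the dataset size of 100 is made up.

```python
import math
import random


def epoch_batches(n, batch_size, shuffle=False, drop_last=False, seed=None):
    """Yield one epoch's worth of index batches, mimicking the default samplers."""
    indices = list(range(n))
    if shuffle:
        random.Random(seed).shuffle(indices)  # a different seed per epoch gives a new order
    for start in range(0, n, batch_size):
        batch = indices[start:start + batch_size]
        if drop_last and len(batch) < batch_size:
            break  # discard the short final batch
        yield batch


def num_batches(n, batch_size, drop_last=False):
    """Batches per epoch: floor(n / batch_size) with drop_last, otherwise ceil(n / batch_size)."""
    return n // batch_size if drop_last else math.ceil(n / batch_size)


n = 100  # hypothetical dataset size
print(num_batches(n, 32), num_batches(n, 32, drop_last=True))  # 4 3
print([len(b) for b in epoch_batches(n, 32)])                   # [32, 32, 32, 4]
```

The floor/ceil formula is the same one len(loader) reports for a map-style dataset, which is handy for sizing learning-rate schedules.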
On Windows and macOS, where Python starts worker processes with the spawn method, num_workers > 0 requires the training code to live inside an if __name__ == "__main__": guard, because each worker re-imports the script. Start with num_workers=0 while developing, then increase it once everything works.
Inspect a batch
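```python
import matplotlib.pyplot as plt

images, labels = next(iter(train_loader))
plt.imshow(images[0].permute(1, 2, 0))  # CHW -> HWC for matplotlib
plt.title(train_data.classes[labels[0].item()]); plt.axis("off")
plt.show()
```

permute(1, 2, 0) reorders the axes from PyTorch’s [C, H, W] layout to the [H, W, C] layout that matplotlib’s imshow expects.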
Class imbalance — WeightedRandomSampler
If one class has 10× more samples than another, the model tends to over-predict the majority class, because most of the training signal comes from it. WeightedRandomSampler counteracts this by drawing minority-class samples more often, so batches come out roughly balanced.
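```python
import torch
from torch.utils.data import DataLoader, WeightedRandomSampler

# Per-class counts
targets = torch.tensor(train_data.targets)
class_counts = torch.bincount(targets)

# Per-sample weight: inverse of its class count
class_weights = 1.0 / class_counts.float()
sample_weights = class_weights[targets]

sampler = WeightedRandomSampler(
    weights=sample_weights,
    num_samples=len(sample_weights),
    replacement=True,
)

balanced_loader = DataLoader(
    train_data,
    batch_size=BATCH_SIZE,
    sampler=sampler,  # mutually exclusive with shuffle=True
    num_workers=0,
)

# Confirm batches are now roughly balanced
_, labels = next(iter(balanced_loader))
print(torch.bincount(labels, minlength=len(train_data.classes)))
```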
The pizza/steak/sushi dataset is already balanced, but the technique transfers to any imbalanced classification task.
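The weighting arithmetic is easy to check by hand. In this stdlib sketch each sample's weight is 1 / (count of its class), so every class carries the same total weight; random.choices with replacement stands in for WeightedRandomSampler as an analogy, not the same code. The 90/10 split is a made-up example.

```python
import random
from collections import Counter

targets = [0] * 90 + [1] * 10  # hypothetical labels: 90 samples of class 0, 10 of class 1
class_counts = Counter(targets)
sample_weights = [1.0 / class_counts[t] for t in targets]  # inverse of the sample's class count

# Every class ends up with the same total weight (1.0, up to float rounding).
class_totals = {c: sum(w for w, t in zip(sample_weights, targets) if t == c) for c in class_counts}
print(class_totals)

# Draw indices with replacement, proportional to weight, as the sampler does.
rng = random.Random(0)
drawn = rng.choices(range(len(targets)), weights=sample_weights, k=10_000)
drawn_counts = Counter(targets[i] for i in drawn)
print(drawn_counts)  # roughly 5000 draws of each class
```

Replacement matters here: without it, the 10 minority samples would run out long before the majority class, and batches could not stay balanced.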
Train a small CNN
The training loop is the same as in the vision chapter. The only changes are the number of input channels (3 for RGB) and the spatial size after pooling.
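```python
import torch
from torch import nn


class SmallCNN(nn.Module):
    def __init__(self, num_classes=3, hidden=16):
        super().__init__()
        self.block1 = nn.Sequential(
            nn.Conv2d(3, hidden, kernel_size=3, padding=1),  # 3 input channels (RGB)
            nn.ReLU(),
            nn.MaxPool2d(2),  # 64x64 -> 32x32
        )
        self.block2 = nn.Sequential(
            nn.Conv2d(hidden, hidden, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),  # 32x32 -> 16x16
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(hidden * 16 * 16, num_classes),
        )

    def forward(self, x):
        return self.classifier(self.block2(self.block1(x)))


device = "cuda" if torch.cuda.is_available() else "cpu"
model = SmallCNN(num_classes=len(train_data.classes)).to(device)
```

For 64×64 input, two MaxPool2d(2) stages bring you down to 16×16, hence hidden * 16 * 16.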
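The 16×16 figure follows from the output-size formula in the PyTorch Conv2d/MaxPool2d docs, which for dilation 1 is floor((size + 2·padding − kernel) / stride) + 1. This small calculator, with hypothetical helper names and hidden=16 as an assumed width, reproduces it and shows what happens when you change the input resolution.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output height/width of Conv2d or MaxPool2d with dilation 1."""
    return (size + 2 * padding - kernel) // stride + 1


def flatten_features(input_size, hidden, num_blocks=2):
    """Flattened size after num_blocks of [Conv2d(k=3, p=1) -> ReLU -> MaxPool2d(2)]."""
    size = input_size
    for _ in range(num_blocks):
        size = conv_out(size, kernel=3, padding=1)  # 3x3 conv with padding 1 keeps the size
        size = conv_out(size, kernel=2, stride=2)   # MaxPool2d(2): stride defaults to kernel_size
    return hidden * size * size


print(flatten_features(64, hidden=16))   # 4096  (= 16 * 16 * 16)
print(flatten_features(128, hidden=16))  # 16384 (= 16 * 32 * 32)
```

Doubling the resize from 64 to 128 quadruples the Linear layer's input, so the hard-coded hidden * 16 * 16 must change with it.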
With this little data the model will not generalize well — that is exactly the problem the transfer-learning chapter solves. For now, train for a few epochs and observe the gap between train and test accuracy.
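```python
import torch
from torch import nn

device = next(model.parameters()).device
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
for epoch in range(5):
    model.train()
    for X, y in train_loader:
        X, y = X.to(device), y.to(device)
        loss = loss_fn(model(X), y)
        optimizer.zero_grad(); loss.backward(); optimizer.step()
    model.eval()
    for name, loader in [("train", train_loader), ("test", test_loader)]:  # train acc uses augmented images
        correct = total = 0
        with torch.inference_mode():
            for X, y in loader:
                X, y = X.to(device), y.to(device)
                preds = model(X).argmax(dim=1)
                correct += (preds == y).sum().item(); total += y.numel()
        print(f"epoch {epoch} {name} acc: {correct / total:.1%}")
```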
Exercises
Warm-up
- Print the first 5 paths in train_data.samples and confirm that __getitem__(i) opens the matching file.
- Build a histogram of image sizes (width × height) before resizing. How varied is the data?
- Add transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) (ImageNet stats) after ToTensor() in both the train and test transforms. Why must it be exactly the same in both?
Custom datasets
- Write a Dataset that reads a CSV with columns (path, label) instead of relying on the folder layout.
- Modify ImageFolderCustom to also accept *.png and *.jpeg.
- Write a Dataset whose targets are bounding-box coordinates (x1, y1, x2, y2) parsed from a JSON file. (Imagine a detection problem.)
Augmentation
- Increase the resize to (128, 128). What changes about the classifier’s input size?
- Replace TrivialAugmentWide with RandAugment(num_ops=2, magnitude=9). Did test accuracy improve?
Sampler
- Delete half of the images in train/pizza. Train once without a sampler and once with WeightedRandomSampler. Compare per-class accuracy.
- Set num_workers=2 inside an if __name__ == "__main__": block on Windows and measure the speedup.
Solution
For the exercise where you delete half of the pizza images, the per-class accuracy gap should be noticeably smaller with the weighted sampler. Measure per-class accuracy on the test set like this:
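```python
import torch

device = next(model.parameters()).device
num_classes = len(train_data.classes)
correct, total = [0] * num_classes, [0] * num_classes
model.eval()
with torch.inference_mode():
    for X, y in test_loader:
        preds = model(X.to(device)).argmax(dim=1).cpu()
        for pred, target in zip(preds.tolist(), y.tolist()):
            total[target] += 1
            if pred == target:
                correct[target] += 1
for name, c, t in zip(train_data.classes, correct, total):
    print(f"{name}: {c / t:.1%}")
```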
Capstone — your own image classifier
Build a 3-class image classifier from scratch using your own photos. This exercises every part of the pipeline you’ve learned: collecting data, organizing it, transforming it, loading it, and training on it.
Step 1 — pick three classes
Anything you can take at least 40 photos of (30 for training and 10 for testing, as in Step 2): pencil/pen/marker, apple/orange/banana, your three favourite mugs. Aim for variety: different lighting, angles, backgrounds.
Step 2 — assemble the folder
Take 30 training photos and 10 test photos per class. Resize them so none is larger than ~1000 px on the long side (your phone images are way too big). Lay them out as:
```text
my_dataset/
├── train/
│   ├── class_a/
│   │   ├── 0001.jpg
│   │   └── ...
│   ├── class_b/
│   └── class_c/
└── test/
    ├── class_a/
    ├── class_b/
    └── class_c/
```

Step 3 — load it
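```python
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

IMG_SIZE = 64
BATCH_SIZE = 16

train_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.RandomHorizontalFlip(),  # pick augmentations that make sense for your classes
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
])
test_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
])

train_data = datasets.ImageFolder("my_dataset/train", transform=train_transform)
test_data = datasets.ImageFolder("my_dataset/test", transform=test_transform)
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
```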
Step 4 — train a small CNN
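```python
import torch
from torch import nn

device = "cuda" if torch.cuda.is_available() else "cpu"


class CapstoneCNN(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, 3, padding=1), nn.BatchNorm2d(16), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(16, 32, 3, padding=1), nn.BatchNorm2d(32), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU(), nn.MaxPool2d(2),
        )
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),  # [N, 64, h, w] -> [N, 64, 1, 1] for any input size
            nn.Flatten(), nn.Dropout(0.3),
            nn.Linear(64, num_classes),
        )

    def forward(self, x):
        return self.classifier(self.features(x))


model = CapstoneCNN(num_classes=len(train_data.classes)).to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
EPOCHS = 30

for epoch in range(EPOCHS):
    model.train()
    for X, y in train_loader:
        X, y = X.to(device), y.to(device)
        loss = loss_fn(model(X), y)
        optimizer.zero_grad(); loss.backward(); optimizer.step()
    model.eval()
    correct = total = 0
    with torch.inference_mode():
        for X, y in test_loader:
            X, y = X.to(device), y.to(device)
            preds = model(X).argmax(dim=1)
            correct += (preds == y).sum().item(); total += y.numel()
    print(f"epoch {epoch + 1}/{EPOCHS}: test acc {correct / total:.1%}")
```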
With 30 photos per class you should not expect great test accuracy — 65–80% is typical. The point of this capstone is the pipeline, not the score. The next chapter, Transfer learning, shows how to push that number much higher with the same data.
Step 5 — predict on a new photo
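```python
import torch
from PIL import Image


def predict(path, model, transform, classes):
    device = next(model.parameters()).device
    img = Image.open(path).convert("RGB")
    x = transform(img).unsqueeze(0).to(device)  # add a batch dimension: [1, 3, H, W]
    model.eval()
    with torch.inference_mode():
        probs = model(x).softmax(dim=1)[0]
    idx = probs.argmax().item()
    return classes[idx], probs[idx].item()


label, confidence = predict("new_photo.jpg", model, test_transform, train_data.classes)  # your photo
```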
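Turning a model's logits into a class name and a confidence score is usually done with softmax followed by argmax. Here is that arithmetic in plain Python; softmax and top_prediction are hypothetical helpers, the class names are the placeholder folder names from Step 2, and the logits are made up.

```python
import math


def softmax(logits):
    """Numerically stable softmax: subtract the max logit before exponentiating."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_prediction(logits, classes):
    """Return (class name, probability) for the highest-scoring class."""
    probs = softmax(logits)
    best = max(range(len(probs)), key=probs.__getitem__)
    return classes[best], probs[best]


classes = ["class_a", "class_b", "class_c"]  # folder names, in ImageFolder order
label, confidence = top_prediction([2.0, 0.5, -1.0], classes)
print(label, round(confidence, 3))  # class_a 0.786
```

Subtracting the max logit does not change the result but keeps math.exp from overflowing on large logits, which is why softmax([1000.0, 1000.0]) still returns [0.5, 0.5].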
Stretch goals
- Compute and display a confusion matrix on your test set (see the vision chapter).
- Add WeightedRandomSampler if some classes ended up with more samples than others.
- Save the dataset class list to a JSON file alongside the model so you don’t depend on folder ordering at inference time.
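The first and last stretch goals are small enough to sketch with the standard library: a confusion matrix built from lists of true and predicted class indices (rows are the true class, columns the predicted class), and a JSON round-trip of the class list. The helper names, the classes.json file name, and the labels below are hypothetical; in your project you would feed in real test-set predictions and save the list next to your model weights.

```python
import json
import os
import tempfile


def confusion_matrix(targets, preds, num_classes):
    """cm[t][p] counts samples whose true class is t and predicted class is p."""
    cm = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(targets, preds):
        cm[t][p] += 1
    return cm


def save_classes(classes, path):
    with open(path, "w", encoding="utf-8") as f:
        json.dump(classes, f)


def load_classes(path):
    with open(path, encoding="utf-8") as f:
        return json.load(f)


classes = ["class_a", "class_b", "class_c"]
targets = [0, 0, 1, 1, 2, 2]  # made-up true labels
preds = [0, 1, 1, 1, 2, 0]    # made-up predictions
cm = confusion_matrix(targets, preds, len(classes))
for name, row in zip(classes, cm):
    print(f"{name:>8}", row)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "classes.json")
    save_classes(classes, path)
    restored = load_classes(path)  # same order as at training time
```

The diagonal holds the correct predictions; off-diagonal cells show which pairs of classes the model confuses.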
Recap
- A Dataset is anything with __len__ and __getitem__. Keep __init__ cheap and load on demand.
- ImageFolder is the shortcut for train/<class>/<image>.jpg layouts.
- Build a transforms pipeline: deterministic preprocessing + augmentation in train, deterministic preprocessing only in test.
- DataLoader turns a Dataset into batches. Use shuffle for train, not for test.
- num_workers > 0 needs an if __name__ == "__main__": guard on Windows and macOS.
- For imbalanced data, swap shuffle=True for a WeightedRandomSampler.
The next chapter, Transfer learning, shows you how to leverage models pretrained on millions of images to get strong results from a tiny dataset like the one you just built.