Custom datasets

Load your own data into PyTorch with Dataset, DataLoader, ImageFolder, samplers and a transforms pipeline.

Goal of the lesson

By the end of this 3-hour session you should be able to:

  • explain the contract of torch.utils.data.Dataset,
  • use ImageFolder for the common image-classification layout,
  • write your own Dataset for any data you have on disk,
  • compose a transforms pipeline with deterministic preprocessing and stochastic augmentation,
  • balance batches drawn from an imbalanced dataset with WeightedRandomSampler,
  • build a small image classifier from a folder of images you assembled yourself.

Suggested timing

| Block | Topic |
|---|---|
| 15 min | Why custom datasets; the Dataset contract |
| 25 min | Getting the data; ImageFolder |
| 30 min | Transforms: deterministic preprocessing vs. augmentation |
| 30 min | Writing a Dataset from scratch |
| 25 min | DataLoader options and WeightedRandomSampler |
| 55 min | Capstone: your own image classifier |

Why custom datasets

Built-in datasets like FashionMNIST and MNIST are training wheels. The moment you have a real project you’ll be loading your own files: photos in folders, audio in WAV files, sensor logs in CSV, MRI scans in DICOM, etc.

PyTorch has a small, composable API for that:

| Building block | Purpose |
|---|---|
| torch.utils.data.Dataset | Your data, indexed by integer. |
| torch.utils.data.DataLoader | Wraps a Dataset and delivers it in batches, with optional shuffling and parallel loading. |
| torchvision.datasets.ImageFolder | A ready-made Dataset for images organized in one folder per class. |
| torchvision.transforms | Image preprocessing and augmentation. |
| torch.utils.data.Sampler | Decides which indices the DataLoader draws, and in what order, on each epoch. |
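To see how these pieces fit together before touching real files, here is a minimal sketch. ToyDataset is a hypothetical in-memory dataset (8 items of class 0, 2 of class 1) invented for illustration; weighting each item by 1 / (size of its class) is one common choice for WeightedRandomSampler, not the only one.

```python
from collections import Counter

import torch
from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler


class ToyDataset(Dataset):
    """Hypothetical in-memory dataset: 8 items of class 0 and 2 items of class 1."""

    def __init__(self):
        self.labels = [0] * 8 + [1] * 2
        # One scalar feature per item, stored with shape (1,) so batches come out as (B, 1).
        self.features = torch.arange(len(self.labels), dtype=torch.float32).unsqueeze(1)

    def __len__(self) -> int:
        # Contract, part 1: how many items exist (valid indices are 0..len-1).
        return len(self.labels)

    def __getitem__(self, idx: int):
        # Contract, part 2: return item idx, here as an (input, target) pair.
        return self.features[idx], self.labels[idx]


ds = ToyDataset()

# DataLoader: fetches items and collates them into batched tensors.
loader = DataLoader(ds, batch_size=4, shuffle=True)
x, y = next(iter(loader))
print(x.shape, y.shape)  # torch.Size([4, 1]) torch.Size([4])

# Sampler: decides which indices the DataLoader requests. Weighting each item by
# 1 / (size of its class) makes the two classes about equally likely per draw.
class_counts = Counter(ds.labels)
weights = [1.0 / class_counts[label] for label in ds.labels]
sampler = WeightedRandomSampler(
    weights,
    num_samples=2000,  # draws per epoch, with replacement
    replacement=True,
    generator=torch.Generator().manual_seed(0),
)
# Pass sampler instead of shuffle=True; DataLoader rejects the two together.
balanced_loader = DataLoader(ds, batch_size=100, sampler=sampler)

drawn = Counter()
for _, batch_labels in balanced_loader:
    drawn.update(batch_labels.tolist())
print(drawn)  # roughly 1000 draws of each class
```

Note that the sampler only changes which indices are requested; the Dataset itself stays the same, so the same class works with or without rebalancing.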

We’ll use a small subset of Food-101 (pizza, steak, sushi) as a running example. The same code patterns work for medical images, satellite images, audio spectrograms, or anything else you load from disk.

Setup

```powershell
uv init --python 3.12 datasets
cd datasets
uv add torch torchvision matplotlib pillow requests
# Afterwards, run the script with: uv run main.py
```
```python
# main.py
from pathlib import Path
import zipfile

import matplotlib.pyplot as plt
import requests
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler
from torchvision import datasets, transforms
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"  # use the GPU when one is available
torch.manual_seed(42)  # reproducible shuffling, augmentation and weight initialization
```

Get the data

The dataset ships as a zip file in the mrdbourke/pytorch-deep-learning GitHub repository. Download it once and unpack it into data/pizza_steak_sushi/; later runs find the folder and skip the download.

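```python
# main.py (continued)
DATA_PATH = Path("data")
IMAGE_PATH = DATA_PATH / "pizza_steak_sushi"

# Download and unpack only once; later runs reuse the extracted folder.
if not IMAGE_PATH.is_dir():
    IMAGE_PATH.mkdir(parents=True, exist_ok=True)

    url = "https://github.com/mrdbourke/pytorch-deep-learning/raw/main/data/pizza_steak_sushi.zip"
    zip_file = DATA_PATH / "pizza_steak_sushi.zip"

    response = requests.get(url, timeout=60)
    response.raise_for_status()  # fail loudly instead of saving an HTML error page as the "zip"
    zip_file.write_bytes(response.content)

    # The archive contains train/ and test/ at its top level.
    with zipfile.ZipFile(zip_file, "r") as z:
        z.extractall(IMAGE_PATH)
    zip_file.unlink()  # the zip is no longer needed once extracted

train_dir = IMAGE_PATH / "train"
test_dir = IMAGE_PATH / "test"
print(sorted(p.name for p in train_dir.iterdir()))  # one folder per class
```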
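Before building anything on top of the download, it is worth checking how many images each class folder holds, since a skewed count is what WeightedRandomSampler later compensates for. A minimal sketch: count_images is a hypothetical helper, the set of image extensions is an assumption, and the demo runs on a throwaway tree of empty placeholder files rather than the real data.

```python
import tempfile
from pathlib import Path

# Assumption: the image extensions you expect to find in your class folders.
IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png"}


def count_images(split_dir: Path) -> dict[str, int]:
    """Hypothetical helper: {class folder name: number of image files} for one split."""
    return {
        class_dir.name: sum(1 for f in class_dir.iterdir() if f.suffix.lower() in IMAGE_SUFFIXES)
        for class_dir in sorted(p for p in split_dir.iterdir() if p.is_dir())
    }


# Demo on a throwaway tree of empty placeholder files (not the real dataset).
demo_split = Path(tempfile.mkdtemp()) / "train"
for class_name, n_files in [("pizza", 3), ("steak", 2), ("sushi", 1)]:
    (demo_split / class_name).mkdir(parents=True)
    for i in range(n_files):
        (demo_split / class_name / f"{i}.jpg").touch()
(demo_split / "pizza" / "notes.txt").touch()  # non-image files are ignored

counts = count_images(demo_split)
print(counts)  # {'pizza': 3, 'steak': 2, 'sushi': 1}
```

On the downloaded data you would call count_images(train_dir) and count_images(test_dir) instead of using the demo tree.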

The folder structure follows the standard image-classification convention, with one subfolder per class inside each split (the layout ImageFolder expects):

```text
pizza_steak_sushi/
├── train/
│   ├── pizza/
│   ├── steak/
│   └── sushi/
└── test/
    ├── pizza/
    ├── steak/
    └── sushi/
```
