
Transfer learning

Reuse pretrained torchvision models on your own dataset to reach strong accuracy with very little data and very little training time.

Goal of the lesson

By the end of this 3-hour session you should be able to:

  • explain why pretrained features generalize across vision tasks,
  • distinguish feature extraction from fine-tuning,
  • load a pretrained torchvision model and inspect its parts,
  • replace the classifier head and freeze the backbone,
  • apply the model’s own preprocessing transforms,
  • train a classifier on a small dataset and reach >90% accuracy in minutes,
  • compare two architectures and report on the winner.

Suggested timing

| Block | Topic |
|---|---|
| 15 min | What transfer learning is and why it works |
| 25 min | Load a pretrained model, inspect its layers |
| 25 min | Replace head, freeze backbone, train |
| 30 min | Evaluate, predict, save |
| 25 min | Fine-tuning the backbone |
| 60 min | Capstone: pretrained model on your own dataset |

Why transfer learning

Training a CNN from scratch needs a lot of data and a lot of compute. The early layers of a network trained on a large, varied dataset learn very generic features (edges, corners, textures, colors) that are useful for almost any image task. The deeper layers combine those into increasingly task-specific patterns.

Transfer learning is the trick that makes deep learning practical for small projects:

  1. take a model pretrained on a large dataset (typically ImageNet, ~1.3M images, 1000 classes),
  2. keep most of its weights,
  3. replace the final classifier with a new one sized for your classes and train only that on your data.

Two flavors:

| Approach | Trainable parameters | When to use |
|---|---|---|
| Feature extraction | Only the new classifier head | Small dataset, similar domain. Default starting point. |
| Fine-tuning | The new head + (some of) the backbone, with a tiny learning rate | Larger dataset, or a domain that drifts noticeably from ImageNet (medical, satellite, drawings). |

We’ll start with feature extraction and add fine-tuning at the end.

Setup

PowerShell
uv init --python 3.12 transfer
cd transfer
uv add torch torchvision matplotlib pillow requests

python
main.py
from pathlib import Path
import zipfile
import requests

import matplotlib.pyplot as plt
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, models

device = "cuda" if torch.cuda.is_available() else "cpu"
torch.manual_seed(42)

We’ll reuse the pizza/steak/sushi dataset from the datasets chapter:

python
main.py
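DATA_PATH = Path("data")
IMAGE_PATH = DATA_PATH / "pizza_steak_sushi"

if not IMAGE_PATH.is_dir():
    url = "https://github.com/mrdbourke/pytorch-deep-learning/raw/main/data/pizza_steak_sushi.zip"
    response = requests.get(url, timeout=60)
    # Fail loudly instead of saving an HTML error page as a .zip file
    response.raise_for_status()
    # Create the folder only after a successful download, so a failed run is retried next time
    IMAGE_PATH.mkdir(parents=True, exist_ok=True)
    zip_file = DATA_PATH / "pizza_steak_sushi.zip"
    zip_file.write_bytes(response.content)
    with zipfile.ZipFile(zip_file, "r") as z:
        z.extractall(IMAGE_PATH)
    zip_file.unlink()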

train_dir = IMAGE_PATH / "train"
test_dir = IMAGE_PATH / "test"
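Before training, it is worth confirming that the download produced the layout datasets.ImageFolder expects: one subfolder per class inside each split. This sketch assumes that split/<class>/<image> layout and treats .jpg, .jpeg and .png files as images, which is our assumption rather than something the dataset guarantees. count_images is a hypothetical helper.

```python
from pathlib import Path

IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png"}  # assumed image file extensions


def count_images(split_dir: Path) -> dict[str, int]:
    """Return {class_name: number_of_images} for a folder laid out as split_dir/<class>/<file>."""
    counts = {}
    for class_dir in sorted(p for p in split_dir.iterdir() if p.is_dir()):
        counts[class_dir.name] = sum(
            1 for f in class_dir.iterdir() if f.suffix.lower() in IMAGE_SUFFIXES
        )
    return counts


image_path = Path("data") / "pizza_steak_sushi"
for split in ("train", "test"):
    split_dir = image_path / split
    if split_dir.is_dir():  # skip quietly if the dataset has not been downloaded yet
        print(split, count_images(split_dir))
```

A strongly unbalanced count or an unexpected extra folder is much cheaper to spot here than after a training run.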

Load a pretrained model

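torchvision.models exposes architectures plus their pretrained weights. Each weights enum member (for example EfficientNet_B0_Weights.DEFAULT) bundles:

  • the URL of the pretrained parameters, downloaded and cached the first time you build the model with them,
  • the transforms the model expects at inference time,
  • metadata (number of parameters, ImageNet accuracy, category names).
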
python
main.py
weights = models.EfficientNet_B0_Weights.DEFAULT

print("ImageNet acc :", weights.meta["_metrics"]["ImageNet-1K"])
print("transforms   :", weights.transforms())
print("classes      :", len(weights.meta["categories"]))

model = models.efficientnet_b0(weights=weights).to(device)
