Note
Go to the end to download the full example code.
SimCLR on ImageNette with the JAX / Flax-NNX backend, data-parallel over GPUs.
Design choice (see the JAX backend docs): by default, augmentation stays on CPU
using the existing torchvision pipeline and we feed NHWC numpy arrays into the
JAX trainer. Nothing in data/transforms.py is reimplemented — the array
boundary is the clean seam between the two backends. Pass ``--gpu-aug`` to
instead only resize on CPU and generate the two SimCLR views on the
accelerator.
Run locally (CPU smoke, tiny synthetic data, no download):
python examples/jax_simclr_imagenette.py --smoke
Run for real on N GPUs (data-parallel; batch is sharded across devices):
python examples/jax_simclr_imagenette.py --epochs 20 --batch-size 256
See examples/jax_simclr_imagenette.slurm for a 2×H200 SLURM launch.
import argparse
import time
import jax
import numpy as np
from flax import nnx
import stable_pretraining.jax as spj
# --------------------------------------------------------------------------- #
# Data: CPU loading -> NHWC numpy batches (two-view torchvision augmentation,
# or resize-only when --gpu-aug moves augmentation onto the accelerator).
# --------------------------------------------------------------------------- #
class _HFView:
    """Map-style dataset over one split of the HF ``frgfm/imagenette`` dataset.

    ``dataset[i]`` converts the stored image to RGB and returns
    ``(aug(img), aug(img), label)`` when ``two_view`` is true, otherwise
    ``(aug(img), label)``; ``label`` is cast to ``int``.

    Defined at module scope (not inside ``build_imagenette_loaders``) so that
    DataLoader worker processes can pickle it under the ``spawn`` /
    ``forkserver`` start methods — a function-local class cannot be pickled.
    It intentionally does not subclass ``torch.utils.data.Dataset`` so that
    importing this module never requires torch; a DataLoader only needs
    ``__len__`` and ``__getitem__`` for a map-style dataset.
    """

    def __init__(self, split, aug, two_view):
        # Local import keeps the module importable without stable_pretraining's
        # data stack (e.g. for the --smoke path).
        import stable_pretraining as spt

        self.ds = spt.data.HFDataset(
            "frgfm/imagenette",
            split=split,
            revision="refs/convert/parquet",
            transform=None,
        )
        self.aug, self.two_view = aug, two_view

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, i):
        sample = self.ds[i]
        img, label = sample["image"].convert("RGB"), int(sample["label"])
        if self.two_view:
            # Two independent draws of the (random) augmentation -> two views.
            return self.aug(img), self.aug(img), label
        return self.aug(img), label


def build_imagenette_loaders(batch_size, img_size, num_workers, gpu_aug=False):
    """Return (train_iter_fn, val_iter_fn, num_classes) backed by frgfm/imagenette.

    ``train_iter_fn`` and ``val_iter_fn`` are zero-argument generator functions
    yielding batches whose images are NHWC numpy arrays; ``num_classes`` is 10.

    When ``gpu_aug`` is True the CPU pipeline does only RGB+Resize+ToTensor
    (cheap, no normalization) and each train batch is a single
    ``{"image", "label"}`` dict; the model's on-device transform generates the
    two SimCLR views. Otherwise the full two-view torchvision augmentation runs
    on CPU and each train batch is ``{"views": [view1, view2]}`` with both
    views sharing one label array. Validation batches are always single-view,
    resized, center-cropped and normalized.

    The train loader shuffles and drops the final partial batch; the
    validation loader does neither.
    """
    import torch
    from torchvision import transforms as T

    norm = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    if gpu_aug:
        # Minimal CPU work — augmentation happens on the accelerator.
        train_aug = T.Compose([T.Resize((img_size, img_size)), T.ToTensor()])
    else:
        train_aug = T.Compose(
            [
                T.RandomResizedCrop(img_size, scale=(0.2, 1.0)),
                T.RandomHorizontalFlip(),
                T.RandomApply([T.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
                T.RandomGrayscale(p=0.2),
                T.ToTensor(),
                norm,
            ]
        )
    # Resize to 1.15x the target size, then center-crop to img_size.
    val_aug = T.Compose(
        [T.Resize(int(img_size * 1.15)), T.CenterCrop(img_size), T.ToTensor(), norm]
    )

    def _loader(ds, shuffle):
        # Training (shuffle=True) drops the ragged last batch so every step has
        # exactly batch_size samples; validation keeps every sample.
        return torch.utils.data.DataLoader(
            ds,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            drop_last=shuffle,
            persistent_workers=num_workers > 0,
        )

    # gpu_aug -> single-view train loader (model makes the two views on device).
    train_dl = _loader(_HFView("train", train_aug, two_view=not gpu_aug), True)
    val_dl = _loader(_HFView("validation", val_aug, two_view=False), False)

    def nhwc(t):  # torch [B,C,H,W] -> numpy [B,H,W,C]
        return np.ascontiguousarray(t.permute(0, 2, 3, 1).numpy())

    def train_iter_cpu():
        for v1, v2, y in train_dl:
            y = y.numpy()
            yield {
                "views": [
                    {"image": nhwc(v1), "label": y},
                    {"image": nhwc(v2), "label": y},
                ]
            }

    def train_iter_gpu():
        for x, y in train_dl:
            yield {"image": nhwc(x), "label": y.numpy()}

    def val_iter():
        for x, y in val_dl:
            yield {"image": nhwc(x), "label": y.numpy()}

    return (train_iter_gpu if gpu_aug else train_iter_cpu), val_iter, 10
def build_synthetic_loaders(batch_size, img_size):
    """Build tiny in-memory NHWC batches for a CPU smoke test (no download).

    Returns ``(train_factory, val_factory, num_classes)``. Each factory is a
    zero-argument callable returning a fresh iterator over a fixed list of
    pre-generated batches: 4 two-view train batches
    (``{"views": [{"image", "label"}, {"image", "label"}]}``, both views
    sharing the same label array) and 2 single-view validation batches
    (``{"image", "label"}``). Images are float32 standard-normal noise of shape
    ``(batch_size, img_size, img_size, 3)``; labels are ints in ``[0, 10)``.
    Everything is drawn from a ``RandomState(0)``, so the data is deterministic.
    """
    rs = np.random.RandomState(0)
    image_shape = (batch_size, img_size, img_size, 3)

    def draw_image():
        return rs.randn(*image_shape).astype("float32")

    def draw_labels():
        return rs.randint(0, 10, size=batch_size)

    def make_train_batch():
        # Draw order matters for determinism: labels first, then both views.
        labels = draw_labels()
        first = {"image": draw_image(), "label": labels}
        second = {"image": draw_image(), "label": labels}
        return {"views": [first, second]}

    def make_val_batch():
        # Image is drawn before the labels.
        image = draw_image()
        return {"image": image, "label": draw_labels()}

    train_batches = [make_train_batch() for _ in range(4)]
    val_batches = [make_val_batch() for _ in range(2)]
    return (lambda: iter(train_batches)), (lambda: iter(val_batches)), 10
class Throughput(spj.Callback):
    """Log images/sec and loss every ``every`` steps (debug instrumentation).

    On each training batch end whose ``trainer.global_step`` is a multiple of
    ``every``, prints the epoch, step, loss and the image rate measured since
    the previous report (or since the current epoch started).

    Args:
        batch_size: images per training step, used to convert steps to images.
        every: report period in global steps; must be >= 1.

    Raises:
        ValueError: if ``every`` < 1 (``every == 0`` would otherwise raise
            ZeroDivisionError on the modulo in the middle of training).
    """

    def __init__(self, batch_size, every=10):
        if every < 1:
            raise ValueError(f"every must be >= 1, got {every!r}")
        self.batch_size, self.every = batch_size, every
        self._t0 = None  # perf_counter() at the start of the current window
        self._last = 0  # trainer.global_step at the start of the current window

    def on_train_epoch_start(self, trainer, module):
        # Restart the window each epoch so time spent between epochs is not
        # counted toward training throughput.
        self._t0 = time.perf_counter()
        self._last = trainer.global_step

    def on_train_batch_end(self, trainer, module, outputs, batch, batch_idx):
        if self._t0 is None:
            # A batch ended before any epoch-start hook: open the timing window
            # now instead of crashing on ``perf_counter() - None``.
            self._t0 = time.perf_counter()
            self._last = trainer.global_step
            return
        if trainer.global_step % self.every == 0:
            dt = time.perf_counter() - self._t0
            n = trainer.global_step - self._last
            ips = (n * self.batch_size) / dt if dt > 0 else 0.0
            # float() brings the loss to the host (presumably a sync point when
            # it is a device array).
            loss = float(outputs["loss"])
            print(
                f"  epoch {trainer.current_epoch} step {trainer.global_step} "
                f"loss {loss:.4f}  {ips:,.0f} img/s",
                flush=True,
            )
            # Start the next window after printing so I/O is not timed.
            self._t0 = time.perf_counter()
            self._last = trainer.global_step
def main():
    """Parse CLI flags, build data, model and callbacks, then train SimCLR.

    ``--smoke`` trains on tiny in-memory synthetic batches (no download) and
    passes no validation data; otherwise ``frgfm/imagenette`` is loaded via
    ``build_imagenette_loaders``. ``--gpu-aug`` attaches
    ``spj.augment.two_view_transform`` to the model so the two SimCLR views are
    built on device from single raw images. Data parallelism is enabled when
    more than one JAX device is visible.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--epochs", type=int, default=20)
    ap.add_argument("--batch-size", type=int, default=256)
    ap.add_argument("--img-size", type=int, default=128)
    ap.add_argument("--num-workers", type=int, default=8)
    ap.add_argument("--lr", type=float, default=1e-3)
    ap.add_argument("--temperature", type=float, default=0.5)
    ap.add_argument("--low-resolution", action="store_true")
    ap.add_argument(
        "--gpu-aug",
        action="store_true",
        help="augment on the accelerator (jnp) instead of CPU torchvision",
    )
    ap.add_argument("--smoke", action="store_true", help="tiny synthetic CPU run")
    args = ap.parse_args()
    print(f"JAX devices: {jax.devices()}", flush=True)

    if args.smoke:
        train_iter, val_iter, num_classes = build_synthetic_loaders(
            args.batch_size, args.img_size
        )
        if args.gpu_aug:
            # Synthetic train batches are two-view ({"views": [v1, v2]}), but
            # with --gpu-aug the model's transform builds the views itself and
            # is fed single {"image", "label"} batches (the layout the
            # ImageNette gpu_aug loader yields). Keep only the first view so
            # `--smoke --gpu-aug` exercises that path instead of mismatching.
            two_view_factory = train_iter

            def single_view_factory():
                for batch in two_view_factory():
                    yield batch["views"][0]

            train_iter = single_view_factory
    else:
        train_iter, val_iter, num_classes = build_imagenette_loaders(
            args.batch_size, args.img_size, args.num_workers, gpu_aug=args.gpu_aug
        )

    # On-device augmentation: the model turns each raw image into two views.
    transform = spj.augment.two_view_transform(args.img_size) if args.gpu_aug else None
    rngs = nnx.Rngs(0)
    backbone = spj.backbone.resnet18(rngs=rngs, low_resolution=args.low_resolution)
    model = spj.SimCLR(
        backbone=backbone,
        embed_dim=backbone.embed_dim,
        rngs=rngs,
        projector_dims=(2048, 2048, 256),
        temperature=args.temperature,
        optim={"type": "lars", "learning_rate": args.lr},
        transform=transform,
    )
    probe = spj.OnlineProbe(
        "linear_probe", probe=nnx.Linear(backbone.embed_dim, num_classes, rngs=rngs)
    )
    trainer = spj.Trainer(
        max_epochs=args.epochs,
        callbacks=[probe, spj.RankMe(), Throughput(args.batch_size)],
        data_parallel=jax.device_count() > 1,
    )
    # train_iter/val_iter are generator factories — re-create each epoch.
    trainer.fit(
        model,
        _Reiterable(train_iter),
        _Reiterable(val_iter) if not args.smoke else None,
    )

    def _fmt(value):
        # A metric may be absent (e.g. 'fit/loss' never recorded) or an array
        # scalar: print "n/a" rather than crash on format(None, ".4f").
        return "n/a" if value is None else f"{float(value):.4f}"

    print(
        f"DONE loss={_fmt(trainer.callback_metrics.get('fit/loss'))} "
        f"probe_acc={_fmt(probe.accuracy)}",
        flush=True,
    )
class _Reiterable:
    """Wrap a zero-argument iterable factory so it can be iterated many times.

    Every ``iter()`` call invokes the factory again and returns an iterator
    over its fresh result, so a generator function yields a new generator per
    pass (e.g. per epoch) instead of being exhausted after the first one.
    """

    def __init__(self, factory):
        # Zero-argument callable returning an iterable (e.g. a generator function).
        self._factory = factory

    def __iter__(self):
        fresh = self._factory()
        return iter(fresh)
# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()