
.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/jax_simclr_imagenette.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_auto_examples_jax_simclr_imagenette.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_jax_simclr_imagenette.py:

SimCLR on ImageNette with the JAX / Flax-NNX backend, data-parallel over GPUs.

Design choice (see the JAX backend docs): by default, augmentation stays on CPU
using the existing torchvision pipeline, and we feed **NHWC numpy arrays** into
the JAX trainer. Nothing in ``data/transforms.py`` is reimplemented — the array
boundary is the clean seam between the two backends. Pass ``--gpu-aug`` to
instead only resize on CPU and generate the two SimCLR views on the accelerator.

Run locally (CPU smoke, tiny synthetic data, no download)::

    python examples/jax_simclr_imagenette.py --smoke

Run for real on N GPUs (data-parallel; batch is sharded across devices)::

    python examples/jax_simclr_imagenette.py --epochs 20 --batch-size 256

See ``examples/jax_simclr_imagenette.slurm`` for a 2×H200 SLURM launch.

.. GENERATED FROM PYTHON SOURCE LINES 18-248

.. code-block:: Python


    import argparse
    import time

    import jax
    import numpy as np
    from flax import nnx

    import stable_pretraining.jax as spj


    # --------------------------------------------------------------------------- #
    # Data: torchvision two-view augmentation on CPU -> NHWC numpy batches.
    # --------------------------------------------------------------------------- #
    def build_imagenette_loaders(batch_size, img_size, num_workers, gpu_aug=False):
        """Build ImageNette (``frgfm/imagenette``) loaders yielding NHWC numpy batches.

        Returns ``(train_iter_fn, val_iter_fn, num_classes)``. Each ``*_iter_fn`` is a
        zero-argument callable that returns a fresh generator over one epoch.

        * ``gpu_aug=False``: torchvision runs the full SimCLR augmentation on CPU
          twice per sample; train batches look like
          ``{"views": [{"image", "label"}, {"image", "label"}]}``.
        * ``gpu_aug=True``: the CPU only resizes and converts to a tensor (no
          normalization); train batches are single-view ``{"image", "label"}`` and
          the model's on-device transform is expected to build the two views.

        Validation batches are always single-view: resized, center-cropped and
        normalized with ImageNet mean/std.
        """
        import torch
        from torchvision import transforms as T

        import stable_pretraining as spt

        imagenet_norm = T.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )

        if not gpu_aug:
            # SimCLR-style augmentation done entirely on CPU.
            train_aug = T.Compose(
                [
                    T.RandomResizedCrop(img_size, scale=(0.2, 1.0)),
                    T.RandomHorizontalFlip(),
                    T.RandomApply([T.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
                    T.RandomGrayscale(p=0.2),
                    T.ToTensor(),
                    imagenet_norm,
                ]
            )
        else:
            # Keep CPU work minimal; the accelerator performs the augmentation.
            train_aug = T.Compose([T.Resize((img_size, img_size)), T.ToTensor()])

        val_aug = T.Compose(
            [
                T.Resize(int(img_size * 1.15)),
                T.CenterCrop(img_size),
                T.ToTensor(),
                imagenet_norm,
            ]
        )

        class _ImagenetteSplit(torch.utils.data.Dataset):
            """One HF imagenette split -> ``(view1, view2, label)`` or ``(view, label)``."""

            def __init__(self, split, aug, two_view):
                self.ds = spt.data.HFDataset(
                    "frgfm/imagenette",
                    split=split,
                    revision="refs/convert/parquet",
                    transform=None,
                )
                self.aug = aug
                self.two_view = two_view

            def __len__(self):
                return len(self.ds)

            def __getitem__(self, idx):
                record = self.ds[idx]
                image = record["image"].convert("RGB")
                label = int(record["label"])
                if not self.two_view:
                    return self.aug(image), label
                # Two independent draws of the random augmentation.
                return self.aug(image), self.aug(image), label

        def make_loader(dataset, is_train):
            # Only the training loader shuffles and drops the ragged final batch.
            return torch.utils.data.DataLoader(
                dataset,
                batch_size=batch_size,
                shuffle=is_train,
                num_workers=num_workers,
                drop_last=is_train,
                persistent_workers=num_workers > 0,
            )

        # With gpu_aug the train split is single-view; the model makes both views.
        train_dl = make_loader(
            _ImagenetteSplit("train", train_aug, two_view=not gpu_aug), True
        )
        # NOTE(review): the val loader keeps its ragged last batch (drop_last=False);
        # confirm the Trainer can shard such a batch when data_parallel is on.
        val_dl = make_loader(
            _ImagenetteSplit("validation", val_aug, two_view=False), False
        )

        def to_nhwc(images):
            """torch ``[B, C, H, W]`` -> contiguous numpy ``[B, H, W, C]``."""
            return np.ascontiguousarray(images.permute(0, 2, 3, 1).numpy())

        def single_view_batches(loader):
            for images, labels in loader:
                yield {"image": to_nhwc(images), "label": labels.numpy()}

        def two_view_batches():
            for first, second, labels in train_dl:
                shared_labels = labels.numpy()
                yield {
                    "views": [
                        {"image": to_nhwc(view), "label": shared_labels}
                        for view in (first, second)
                    ]
                }

        def single_view_train():
            return single_view_batches(train_dl)

        def single_view_val():
            return single_view_batches(val_dl)

        train_iter_fn = single_view_train if gpu_aug else two_view_batches
        return train_iter_fn, single_view_val, 10


    def build_synthetic_loaders(batch_size, img_size):
        """Build seeded, in-memory random NHWC batches for a no-download smoke test.

        Returns ``(train_iter_fn, val_iter_fn, num_classes)``:

        * train: 4 two-view batches ``{"views": [{"image", "label"}] * 2}`` whose
          two views share one label array;
        * val: 2 single-view batches ``{"image", "label"}``.

        Images are standard-normal ``float32`` arrays of shape
        ``(batch_size, img_size, img_size, 3)``; labels are ints in ``[0, 10)``.
        All batches are generated up front from ``RandomState(0)``, so every call
        of an iterator factory replays the same data.
        """
        rng = np.random.RandomState(0)
        shape = (batch_size, img_size, img_size, 3)

        def images():
            return rng.randn(*shape).astype("float32")

        def labels():
            return rng.randint(0, 10, size=batch_size)

        train_batches = []
        for _ in range(4):
            shared = labels()  # drawn before the images, as both views reuse it
            train_batches.append(
                {
                    "views": [
                        {"image": images(), "label": shared},
                        {"image": images(), "label": shared},
                    ]
                }
            )

        val_batches = [{"image": images(), "label": labels()} for _ in range(2)]

        def train_iter_fn():
            return iter(train_batches)

        def val_iter_fn():
            return iter(val_batches)

        return train_iter_fn, val_iter_fn, 10


    class Throughput(spj.Callback):
        """Print images/sec and the current loss every ``every`` steps (debug aid).

        Throughput is measured over the window since the previous report (or the
        start of the current epoch), assuming ``batch_size`` images per step.
        """

        def __init__(self, batch_size, every=10):
            self.batch_size, self.every = batch_size, every
            self._t0 = None  # perf_counter() at the start of the current window
            self._last = 0  # trainer.global_step at the start of the current window

        def on_train_epoch_start(self, trainer, module):
            self._t0 = time.perf_counter()
            self._last = trainer.global_step

        def on_train_batch_end(self, trainer, module, outputs, batch, batch_idx):
            if self._t0 is None:
                # No epoch-start hook has run yet: open the window here instead of
                # failing on ``perf_counter() - None``.
                self._t0 = time.perf_counter()
                self._last = trainer.global_step
                return
            if trainer.global_step % self.every != 0:
                return
            # JAX dispatches device work asynchronously. Materialising the loss as
            # a Python float waits for that work (assuming ``outputs["loss"]`` is a
            # device array), so do it BEFORE reading the clock; otherwise the
            # window can close while the latest step is still running.
            loss = float(outputs["loss"])
            dt = time.perf_counter() - self._t0
            n = trainer.global_step - self._last
            ips = (n * self.batch_size) / dt if dt > 0 else 0.0
            print(
                f"  epoch {trainer.current_epoch} step {trainer.global_step} "
                f"loss {loss:.4f}  {ips:,.0f} img/s",
                flush=True,
            )
            self._t0 = time.perf_counter()
            self._last = trainer.global_step


    def main():
        """Command-line entry point: build data and a SimCLR model, then train.

        ``--smoke`` uses in-memory synthetic data instead of downloading
        ImageNette; ``--gpu-aug`` moves the two-view augmentation onto the
        accelerator. Prints the final training loss and linear-probe accuracy.
        """
        ap = argparse.ArgumentParser()
        ap.add_argument("--epochs", type=int, default=20)
        ap.add_argument("--batch-size", type=int, default=256)
        ap.add_argument("--img-size", type=int, default=128)
        ap.add_argument("--num-workers", type=int, default=8)
        ap.add_argument("--lr", type=float, default=1e-3)
        ap.add_argument("--temperature", type=float, default=0.5)
        ap.add_argument("--low-resolution", action="store_true")
        ap.add_argument(
            "--gpu-aug",
            action="store_true",
            help="augment on the accelerator (jnp) instead of CPU torchvision",
        )
        ap.add_argument("--smoke", action="store_true", help="tiny synthetic CPU run")
        args = ap.parse_args()

        print(f"JAX devices: {jax.devices()}", flush=True)

        if args.smoke:
            synth_train_iter, val_iter, num_classes = build_synthetic_loaders(
                args.batch_size, args.img_size
            )
            if args.gpu_aug:
                # The synthetic loader always yields two-view batches, but with
                # --gpu-aug the model builds both views on device from a single
                # raw image -- the single-view {"image", "label"} format that
                # build_imagenette_loaders(gpu_aug=True) emits. Feed only the
                # first view so the smoke path matches the real one.
                def train_iter():
                    for batch in synth_train_iter():
                        yield batch["views"][0]

            else:
                train_iter = synth_train_iter
        else:
            train_iter, val_iter, num_classes = build_imagenette_loaders(
                args.batch_size, args.img_size, args.num_workers, gpu_aug=args.gpu_aug
            )

        # On-device augmentation: the model turns each raw image into two views.
        transform = spj.augment.two_view_transform(args.img_size) if args.gpu_aug else None

        rngs = nnx.Rngs(0)
        backbone = spj.backbone.resnet18(rngs=rngs, low_resolution=args.low_resolution)
        model = spj.SimCLR(
            backbone=backbone,
            embed_dim=backbone.embed_dim,
            rngs=rngs,
            projector_dims=(2048, 2048, 256),
            temperature=args.temperature,
            optim={"type": "lars", "learning_rate": args.lr},
            transform=transform,
        )
        probe = spj.OnlineProbe(
            "linear_probe", probe=nnx.Linear(backbone.embed_dim, num_classes, rngs=rngs)
        )
        trainer = spj.Trainer(
            max_epochs=args.epochs,
            callbacks=[probe, spj.RankMe(), Throughput(args.batch_size)],
            data_parallel=jax.device_count() > 1,
        )
        # train_iter/val_iter are generator factories — re-create each epoch.
        trainer.fit(
            model,
            _Reiterable(train_iter),
            _Reiterable(val_iter) if not args.smoke else None,
        )

        def fmt(value):
            # A metric may be absent (e.g. the probe accuracy is presumably unset
            # when --smoke skips validation); print "n/a" instead of crashing on
            # ``format(None, ".4f")``.
            return "n/a" if value is None else f"{float(value):.4f}"

        print(
            f"DONE  loss={fmt(trainer.callback_metrics.get('fit/loss'))}  "
            f"probe_acc={fmt(probe.accuracy)}",
            flush=True,
        )


    class _Reiterable:
        """Adapt a generator-factory into a re-iterable object the Trainer can loop."""

        def __init__(self, factory):
            self._factory = factory

        def __iter__(self):
            return iter(self._factory())


    # Only train when this file is executed as a script, not when it is imported.
    if __name__ == "__main__":
        main()


.. _sphx_glr_download_auto_examples_jax_simclr_imagenette.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: jax_simclr_imagenette.ipynb <jax_simclr_imagenette.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: jax_simclr_imagenette.py <jax_simclr_imagenette.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: jax_simclr_imagenette.zip <jax_simclr_imagenette.zip>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
