"""SimCLR on ImageNette with the JAX / Flax-NNX backend, data-parallel over GPUs.

Design choice (see the JAX backend docs): by default, augmentation stays on CPU
using the existing torchvision pipeline and we feed **NHWC numpy arrays** into
the JAX trainer (``--gpu-aug`` instead moves the two-view augmentation onto the
accelerator). Nothing in ``data/transforms.py`` is reimplemented — the array
boundary is the clean seam between the two backends.

Run locally (CPU smoke test: a few in-memory synthetic batches sized by
``--batch-size``/``--img-size``, no download)::

    python examples/jax_simclr_imagenette.py --smoke

Run for real on N GPUs (data-parallel; batch is sharded across devices)::

    python examples/jax_simclr_imagenette.py --epochs 20 --batch-size 256

See ``examples/jax_simclr_imagenette.slurm`` for a 2×H200 SLURM launch.
"""

import argparse
import time

import jax
import numpy as np
from flax import nnx

import stable_pretraining.jax as spj


# --------------------------------------------------------------------------- #
# Data: torchvision two-view augmentation on CPU -> NHWC numpy batches.
# --------------------------------------------------------------------------- #
def build_imagenette_loaders(batch_size, img_size, num_workers, gpu_aug=False):
    """Create ImageNette iterator factories that yield NHWC numpy batches.

    Samples come from the ``frgfm/imagenette`` Hugging Face dataset (parquet
    revision) and are batched by torch ``DataLoader``s. Each tensor batch is
    converted to a contiguous ``[B, H, W, C]`` numpy array at the boundary.

    Args:
        batch_size: Batch size of both the train and validation loaders.
        img_size: Side length of the square images produced.
        num_workers: ``DataLoader`` worker count (workers persist when > 0).
        gpu_aug: When True, the CPU only resizes to ``(img_size, img_size)`` and
            converts to a tensor. Each train batch then carries a single raw
            image per sample, and the model's on-device transform creates the
            two SimCLR views. When False, the full two-view torchvision SimCLR
            augmentation runs on CPU.

    Returns:
        ``(train_iter_fn, val_iter_fn, num_classes)``. Calling an
        ``*_iter_fn`` returns a fresh generator over its loader. Train batches
        are either ``{"views": [{"image", "label"}, {"image", "label"}]}`` (CPU
        augmentation) or ``{"image", "label"}`` (``gpu_aug``). Validation
        batches are always ``{"image", "label"}``. ``num_classes`` is 10.
    """
    import torch
    from torchvision import transforms as T

    import stable_pretraining as spt

    imagenet_norm = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    if gpu_aug:
        # Cheap CPU work only; the accelerator does the augmentation.
        # NOTE(review): unlike the validation pipeline there is no Normalize
        # here. Presumably the on-device transform normalizes; confirm.
        train_steps = [T.Resize((img_size, img_size)), T.ToTensor()]
    else:
        train_steps = [
            T.RandomResizedCrop(img_size, scale=(0.2, 1.0)),
            T.RandomHorizontalFlip(),
            T.RandomApply([T.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
            T.RandomGrayscale(p=0.2),
            T.ToTensor(),
            imagenet_norm,
        ]
    train_aug = T.Compose(train_steps)

    # Deterministic eval view: resize (shorter side ~1.15x), then center-crop.
    val_aug = T.Compose(
        [
            T.Resize(int(img_size * 1.15)),
            T.CenterCrop(img_size),
            T.ToTensor(),
            imagenet_norm,
        ]
    )

    class HFView(torch.utils.data.Dataset):
        """Map-style view over an HF imagenette split that applies ``aug``."""

        def __init__(self, split, aug, two_view):
            self.ds = spt.data.HFDataset(
                "frgfm/imagenette",
                split=split,
                revision="refs/convert/parquet",
                transform=None,
            )
            self.aug = aug
            self.two_view = two_view

        def __len__(self):
            return len(self.ds)

        def __getitem__(self, i):
            record = self.ds[i]
            img = record["image"].convert("RGB")
            label = int(record["label"])
            if not self.two_view:
                return self.aug(img), label
            # Two independent draws of the random augmentation on the same image.
            return self.aug(img), self.aug(img), label

    def make_loader(dataset, shuffle):
        # Shuffled (train) loaders drop the ragged last batch; validation keeps it.
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            drop_last=shuffle,
            persistent_workers=num_workers > 0,
        )

    # With gpu_aug the train split is single-view; the model builds both views.
    train_dl = make_loader(HFView("train", train_aug, two_view=not gpu_aug), True)
    val_dl = make_loader(HFView("validation", val_aug, two_view=False), False)

    def to_nhwc(batch):
        """Turn a torch ``[B, C, H, W]`` tensor into a contiguous ``[B, H, W, C]`` array."""
        return np.ascontiguousarray(batch.permute(0, 2, 3, 1).numpy())

    def single_view_batches(loader):
        for x, y in loader:
            yield {"image": to_nhwc(x), "label": y.numpy()}

    def train_iter_two_view():
        for view_a, view_b, y in train_dl:
            labels = y.numpy()
            first = {"image": to_nhwc(view_a), "label": labels}
            second = {"image": to_nhwc(view_b), "label": labels}
            yield {"views": [first, second]}

    def train_iter_single_view():
        return single_view_batches(train_dl)

    def val_iter():
        return single_view_batches(val_dl)

    train_iter = train_iter_single_view if gpu_aug else train_iter_two_view
    return train_iter, val_iter, 10


def build_synthetic_loaders(batch_size, img_size):
    """Build small in-memory NHWC batches for a smoke test that needs no download.

    All data is drawn once, up front, from ``np.random.RandomState(0)``, so
    every run sees identical batches.

    Args:
        batch_size: Samples per batch.
        img_size: Height and width of each float32 ``[B, H, W, 3]`` image batch.

    Returns:
        ``(train_iter_fn, val_iter_fn, num_classes)``. The train factory yields
        4 two-view batches ``{"views": [{"image", "label"}, {"image", "label"}]}``
        whose views share one label array. The val factory yields 2 batches
        ``{"image", "label"}``. Labels are integers in ``[0, 10)``;
        ``num_classes`` is 10.
    """
    rng = np.random.RandomState(0)
    shape = (batch_size, img_size, img_size, 3)

    def noise():
        return rng.randn(*shape).astype("float32")

    def labels():
        return rng.randint(0, 10, size=batch_size)

    def two_view_batch():
        # Draw order (labels, view 1, view 2) fixes the generated values.
        y = labels()
        first = {"image": noise(), "label": y}
        second = {"image": noise(), "label": y}
        return {"views": [first, second]}

    def single_batch():
        # Image first, then labels, to keep the RNG stream order.
        image = noise()
        return {"image": image, "label": labels()}

    train_batches = [two_view_batch() for _ in range(4)]
    val_batches = [single_batch() for _ in range(2)]

    def train_fn():
        return iter(train_batches)

    def val_fn():
        return iter(val_batches)

    return train_fn, val_fn, 10


class Throughput(spj.Callback):
    """Print training throughput (images/sec) and the loss every ``every`` steps.

    Debug instrumentation. The rate is measured over the steps completed since
    the previous report (or since the start of the epoch), counting
    ``batch_size`` images per step.

    Args:
        batch_size: Images per training step (``main`` passes ``--batch-size``).
        every: Report whenever ``trainer.global_step`` is a multiple of this.
    """

    def __init__(self, batch_size, every=10):
        self.batch_size, self.every = batch_size, every
        self._t0 = None  # perf_counter() at the start of the current window
        self._last = 0  # trainer.global_step at the start of the current window

    def on_train_epoch_start(self, trainer, module):
        """Open a fresh measurement window at the start of every epoch."""
        self._t0 = time.perf_counter()
        self._last = trainer.global_step

    def on_train_batch_end(self, trainer, module, outputs, batch, batch_idx):
        """On every ``every``-th global step, report rate and loss, then reset the window."""
        if trainer.global_step % self.every != 0:
            return
        # Materialize the loss *before* reading the clock. outputs["loss"] is
        # presumably a JAX device array, and JAX dispatches work asynchronously:
        # the host can run ahead of the device. Converting to a Python float
        # waits until the loss is computed, so the window closes when the work
        # is done rather than when it was merely queued. Otherwise img/s is
        # overestimated and the blocking wait leaks into the next window.
        loss = float(outputs["loss"])
        dt = time.perf_counter() - self._t0
        n = trainer.global_step - self._last
        ips = (n * self.batch_size) / dt if dt > 0 else 0.0
        print(
            f"  epoch {trainer.current_epoch} step {trainer.global_step} "
            f"loss {loss:.4f}  {ips:,.0f} img/s",
            flush=True,
        )
        self._t0 = time.perf_counter()
        self._last = trainer.global_step


def _first_view_only(factory):
    """Wrap a two-view batch factory so it yields only each batch's first view.

    Used for ``--smoke --gpu-aug``: synthetic batches are always
    ``{"views": [v1, v2]}``. With on-device augmentation, however, the model
    expects one raw ``{"image", "label"}`` batch per step and builds the two
    views itself (the shape ``build_imagenette_loaders`` emits when
    ``gpu_aug=True``).
    """

    def single_view():
        for batch in factory():
            yield batch["views"][0]

    return single_view


def main():
    """Parse CLI arguments, build data and a SimCLR model, then train it.

    ``--smoke`` trains on in-memory synthetic batches and skips validation;
    otherwise ImageNette is loaded through ``build_imagenette_loaders``. With
    ``--gpu-aug`` the model is given ``spj.augment.two_view_transform`` and
    receives single raw images, from which it creates the two views on device.
    Data parallelism is enabled whenever more than one JAX device is visible.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--epochs", type=int, default=20)
    ap.add_argument("--batch-size", type=int, default=256)
    ap.add_argument("--img-size", type=int, default=128)
    ap.add_argument("--num-workers", type=int, default=8)
    ap.add_argument("--lr", type=float, default=1e-3)
    ap.add_argument("--temperature", type=float, default=0.5)
    ap.add_argument("--low-resolution", action="store_true")
    ap.add_argument(
        "--gpu-aug",
        action="store_true",
        help="augment on the accelerator (jnp) instead of CPU torchvision",
    )
    ap.add_argument("--smoke", action="store_true", help="tiny synthetic CPU run")
    args = ap.parse_args()

    print(f"JAX devices: {jax.devices()}", flush=True)

    if args.smoke:
        train_iter, val_iter, num_classes = build_synthetic_loaders(
            args.batch_size, args.img_size
        )
        if args.gpu_aug:
            # Synthetic train batches are two-view, but the on-device transform
            # turns each raw image into two views itself: feed one view only.
            train_iter = _first_view_only(train_iter)
    else:
        train_iter, val_iter, num_classes = build_imagenette_loaders(
            args.batch_size, args.img_size, args.num_workers, gpu_aug=args.gpu_aug
        )

    # On-device augmentation: the model turns each raw image into two views.
    transform = spj.augment.two_view_transform(args.img_size) if args.gpu_aug else None

    rngs = nnx.Rngs(0)
    backbone = spj.backbone.resnet18(rngs=rngs, low_resolution=args.low_resolution)
    model = spj.SimCLR(
        backbone=backbone,
        embed_dim=backbone.embed_dim,
        rngs=rngs,
        projector_dims=(2048, 2048, 256),
        temperature=args.temperature,
        optim={"type": "lars", "learning_rate": args.lr},
        transform=transform,
    )
    probe = spj.OnlineProbe(
        "linear_probe", probe=nnx.Linear(backbone.embed_dim, num_classes, rngs=rngs)
    )
    trainer = spj.Trainer(
        max_epochs=args.epochs,
        callbacks=[probe, spj.RankMe(), Throughput(args.batch_size)],
        data_parallel=jax.device_count() > 1,
    )
    # train_iter/val_iter are generator factories — re-create each epoch.
    trainer.fit(
        model,
        _Reiterable(train_iter),
        _Reiterable(val_iter) if not args.smoke else None,
    )

    def fmt(value):
        # ``callback_metrics.get`` returns None for a missing key, and the probe
        # accuracy may be unset, e.g. when validation is skipped (unverified).
        # format(None, ".4f") raises TypeError, which would crash after training.
        return "n/a" if value is None else f"{float(value):.4f}"

    loss = trainer.callback_metrics.get("fit/loss")
    print(
        f"DONE  loss={fmt(loss)}  probe_acc={fmt(probe.accuracy)}",
        flush=True,
    )


class _Reiterable:
    """Adapt a generator-factory into a re-iterable object the Trainer can loop."""

    def __init__(self, factory):
        self._factory = factory

    def __iter__(self):
        return iter(self._factory())


if __name__ == "__main__":
    # Script entry point: train only when executed directly, not when imported.
    main()
