Source code for stable_datasets.images.imagenet_1k

import io
import tarfile
from pathlib import Path

from PIL import Image as PILImage

from stable_datasets.schema import ClassLabel, DatasetInfo, Features, Image, Version
from stable_datasets.splits import Split, SplitGenerator
from stable_datasets.utils import BaseDatasetBuilder, download


def _default_class_names(count: int) -> list[str]:
    width = max(2, len(str(count - 1)))
    return [f"class_{idx:0{width}d}" for idx in range(count)]


class _ImageNetArchiveMixin:
    def _iter_inner_images(self, class_tar_bytes: bytes, class_name: str, label: int):
        with tarfile.open(fileobj=io.BytesIO(class_tar_bytes), mode="r:*") as inner:
            for image_member in inner:
                if not image_member.isfile():
                    continue
                if not image_member.name.lower().endswith((".jpg", ".jpeg", ".png")):
                    continue

                image_file = inner.extractfile(image_member)
                if image_file is None:
                    continue

                image = PILImage.open(io.BytesIO(image_file.read())).convert("RGB")
                yield f"{class_name}/{image_member.name}", {"image": image, "label": label}

    def _iter_train_examples(self, archive_path: Path, class_limit: int | None):
        if self.streaming:
            with tarfile.open(archive_path, "r|*") as outer:
                class_count = 0
                for member in outer:
                    if not member.isfile() or not member.name.endswith(".tar"):
                        continue
                    if class_limit is not None and class_count >= class_limit:
                        break

                    class_file = outer.extractfile(member)
                    if class_file is None:
                        continue

                    yield from self._iter_inner_images(class_file.read(), Path(member.name).stem, class_count)
                    class_count += 1
            return

        with tarfile.open(archive_path, "r:*") as outer:
            class_count = 0
            for member in outer:
                if not member.isfile() or not member.name.endswith(".tar"):
                    continue
                if class_limit is not None and class_count >= class_limit:
                    break

                class_file = outer.extractfile(member)
                if class_file is None:
                    continue
                yield from self._iter_inner_images(class_file.read(), Path(member.name).stem, class_count)
                class_count += 1


[docs] class ImageNet1K(_ImageNetArchiveMixin, BaseDatasetBuilder): VERSION = Version("2.0.0") SOURCE = { "homepage": "https://www.image-net.org/challenges/LSVRC/2012/", "assets": {"train": "https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_train.tar"}, "citation": """@article{deng2009imagenet, title={ImageNet: A large-scale hierarchical image database}, author={Deng, Jia and others}, journal={CVPR}, year={2009} }""", } def __init__(self, streaming: bool = True, **kwargs): self.streaming = streaming super().__init__(**kwargs) def _info(self): return DatasetInfo( description="ImageNet-1K training split in TAR format with optional streaming.", features=Features({"image": Image(), "label": ClassLabel(names=_default_class_names(1000))}), supervised_keys=("image", "label"), homepage=self.SOURCE["homepage"], citation=self.SOURCE["citation"], ) def _split_generators(self): train_path = download(self.SOURCE["assets"]["train"], dest_folder=self._raw_download_dir) return [SplitGenerator(name=Split.TRAIN, gen_kwargs={"data_path": train_path})] def _generate_examples(self, data_path, split=None): yield from self._iter_train_examples(Path(data_path), class_limit=None)