import io
import tarfile
from pathlib import Path
from PIL import Image as PILImage
from stable_datasets.schema import (
ClassLabel,
DatasetInfo,
DatasetSource,
DownloadInfo,
Features,
Image,
Version,
)
from stable_datasets.splits import Split, SplitGenerator
from stable_datasets.utils import BaseDatasetBuilder, _default_dest_folder, bulk_download
_IN10_CLASSES = [
"n01440764",
"n02102040",
"n02979186",
"n03000684",
"n03028079",
"n03394916",
"n03417042",
"n03425413",
"n03445777",
"n03888257",
]
class Imagenette(BaseDatasetBuilder):
    """Imagenette (ImageNet-10) from FastAI's public tarball.

    Images are streamed directly out of the downloaded archive (nothing is
    extracted to disk) and decoded to RGB PIL images. The archive's ``train``
    folder feeds ``Split.TRAIN``; its ``val`` folder is exposed as
    ``Split.TEST``.
    """

    VERSION = Version("2.0.0")

    SOURCE = DatasetSource(
        homepage="https://github.com/fastai/imagenette",
        assets={
            "archive": DownloadInfo(url="https://s3.amazonaws.com/fast-ai-imageclas/imagenette2.tgz"),
        },
        citation="""@misc{howard2019imagenette,
author={Jeremy Howard},
title={Imagenette: A smaller subset of 10 easily classified classes from ImageNet},
year={2019},
url={https://github.com/fastai/imagenette}
}""",
    )

    def _info(self):
        """Describe the schema: an image plus a label drawn from ``_IN10_CLASSES``."""
        return DatasetInfo(
            description="ImageNet-10 (Imagenette) with train/validation splits.",
            features=Features(
                {
                    "image": Image(),
                    "label": ClassLabel(names=_IN10_CLASSES),
                }
            ),
            supervised_keys=("image", "label"),
            homepage=self.SOURCE["homepage"],
            citation=self.SOURCE["citation"],
        )

    def _split_generators(self, dl_manager):
        """Download the archive and declare the train and test splits.

        ``dl_manager`` belongs to the builder interface but is not used here.
        The archive is fetched with ``bulk_download`` into
        ``self._raw_download_dir`` when that attribute is set (and not None),
        otherwise into ``_default_dest_folder()``.

        Returns:
            A TRAIN generator reading the archive's ``train`` folder and a
            TEST generator reading its ``val`` folder.
        """
        source = self._source()
        assets = source["assets"]
        # NOTE(review): the asset entries (DownloadInfo objects) are passed
        # as-is; this assumes bulk_download accepts them and returns local
        # paths in input order.
        urls = list(assets.values())
        download_dir = getattr(self, "_raw_download_dir", None)
        if download_dir is None:
            download_dir = _default_dest_folder()
        downloaded_paths = bulk_download(urls, dest_folder=download_dir)
        # Only one asset is declared, so the archive is the first result.
        archive_path = downloaded_paths[0]
        return [
            SplitGenerator(
                name=Split.TRAIN,
                gen_kwargs={"data_path": archive_path, "split": "train"},
            ),
            # The archive's validation folder is published as the TEST split.
            SplitGenerator(
                name=Split.TEST,
                gen_kwargs={"data_path": archive_path, "split": "val"},
            ),
        ]

    def _generate_examples(self, data_path, split):
        """Yield ``(key, example)`` pairs for one split, read from the tarball.

        Args:
            data_path: Path to the downloaded archive. It is opened with mode
                ``"r:*"``, so any compression ``tarfile`` can detect is accepted.
            split: Top-level folder under ``imagenette2/`` to read
                (``"train"`` or ``"val"``).

        Yields:
            The member's path inside the archive (used as the unique key) and a
            dict with an RGB ``PIL.Image`` under ``"image"`` and the index of
            the member's WordNet-ID folder in ``_IN10_CLASSES`` under
            ``"label"``. Non-image members, other splits, and unknown class
            folders are skipped.
        """
        # Built once per call instead of scanning the class list with
        # ``in`` / ``.index()`` for every archive member.
        label_of = {wnid: idx for idx, wnid in enumerate(_IN10_CLASSES)}
        with tarfile.open(Path(data_path), "r:*") as archive:
            for member in archive:
                if not member.isfile() or not member.name.lower().endswith((".jpg", ".jpeg", ".png")):
                    continue
                # Drop empty and "." components so archives repacked with a
                # leading "./" (or doubled slashes) still match the expected
                # layout imagenette2/<split>/<wnid>/<file>.
                parts = [p for p in member.name.split("/") if p not in ("", ".")]
                if len(parts) < 4 or parts[0] != "imagenette2" or parts[1] != split:
                    continue
                label = label_of.get(parts[2])
                if label is None:
                    continue
                file_obj = archive.extractfile(member)
                if file_obj is None:
                    continue
                # Close both the extracted stream and the PIL handle;
                # convert() returns a new, fully loaded image, so it remains
                # valid after they are closed.
                with file_obj:
                    data = file_obj.read()
                with PILImage.open(io.BytesIO(data)) as img:
                    image = img.convert("RGB")
                yield member.name, {"image": image, "label": label}