Source code for stable_datasets.images.k_mnist

import datasets
import numpy as np

from stable_datasets.utils import (
    BaseDatasetBuilder,
    _default_dest_folder,
    bulk_download,
)


class KMNIST(BaseDatasetBuilder):
    """Image classification.

    The `Kuzushiji-MNIST <http://codh.rois.ac.jp/kmnist/>`_ dataset consists of
    70,000 28x28 grayscale images of 10 classes of Kuzushiji (cursive Japanese)
    characters, with 7,000 images per class. There are 60,000 training images
    and 10,000 test images. Kuzushiji-MNIST is a drop-in replacement for the
    MNIST dataset, providing a more challenging alternative for benchmarking
    machine learning algorithms.
    """

    VERSION = datasets.Version("1.0.0")

    # Single source-of-truth for dataset provenance + download locations.
    SOURCE = {
        "homepage": "http://codh.rois.ac.jp/kmnist/",
        "assets": {
            "train": "https://codh.rois.ac.jp/kmnist/dataset/kmnist/kmnist-train-imgs.npz",
            "test": "https://codh.rois.ac.jp/kmnist/dataset/kmnist/kmnist-test-imgs.npz",
        },
        "citation": """@online{clanuwat2018deep, author = {Tarin Clanuwat and Mikel Bober-Irizar and Asanobu Kitamoto and Alex Lamb and Kazuaki Yamamoto and David Ha}, title = {Deep Learning for Classical Japanese Literature}, date = {2018-12-03}, year = {2018}, eprintclass = {cs.CV}, eprinttype = {arXiv}, eprint = {cs.CV/1812.01718}}""",
    }

    def _info(self):
        """Return the :class:`datasets.DatasetInfo` describing this dataset.

        Declares the feature schema (an ``Image`` plus a 10-way
        ``ClassLabel``), the supervised keys, and provenance metadata taken
        from :data:`SOURCE`.
        """
        return datasets.DatasetInfo(
            description="""The Kuzushiji-MNIST dataset is an image classification dataset of 60,000 28x28 grayscale training images and 10,000 test images, labeled over 10 classes of cursive Japanese characters. 
See http://codh.rois.ac.jp/kmnist/ for more information.""",
            features=datasets.Features(
                {
                    "image": datasets.Image(),
                    "label": datasets.ClassLabel(
                        names=[
                            "o",
                            "ki",
                            "su",
                            "tsu",
                            "na",
                            "ha",
                            "ma",
                            "ya",
                            "re",
                            "wo",
                        ]
                    ),
                }
            ),
            supervised_keys=("image", "label"),
            homepage=self.SOURCE["homepage"],
            citation=self.SOURCE["citation"],
        )

    def _generate_examples(self, data_path, split):
        """Yield ``(index, example)`` pairs from the npz archives.

        Args:
            data_path: Local path to the already-downloaded ``*-imgs.npz``
                archive for ``split``.
            split: ``"train"`` or ``"test"``; used to derive the URL of the
                matching ``*-labels.npz`` archive from :data:`SOURCE`.

        Yields:
            Tuples of ``(idx, {"image": ..., "label": int})``.
        """
        # Both archives are consumed as plain "arr_0" arrays, so pickle
        # deserialization is not needed. allow_pickle=False (NumPy's safe
        # default) prevents arbitrary-code execution via a tampered
        # downloaded file; the previous allow_pickle=True re-enabled it.
        data = np.load(data_path, allow_pickle=False)
        images = data["arr_0"]

        # Labels live in a sibling archive next to the image archive; derive
        # its URL from the image URL for this split.
        label_url = self.SOURCE["assets"][split].replace("-imgs.npz", "-labels.npz")
        download_dir = getattr(self, "_raw_download_dir", None)
        if download_dir is None:
            download_dir = _default_dest_folder()

        # Download the label file alongside the raw image downloads.
        label_paths = bulk_download([label_url], dest_folder=download_dir)
        labels = np.load(label_paths[0], allow_pickle=False)["arr_0"]

        for idx, (image, label) in enumerate(zip(images, labels)):
            yield idx, {"image": image, "label": int(label)}