DataModule

class stable_pretraining.data.DataModule(train: dict | DictConfig | DataLoader | None = None, test: dict | DictConfig | DataLoader | None = None, val: dict | DictConfig | DataLoader | None = None, predict: dict | DictConfig | DataLoader | None = None, **kwargs)

Bases: LightningDataModule

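PyTorch Lightning DataModule for handling the train/val/test/predict dataloaders. Each split argument accepts a DataLoader, a dict or DictConfig, or None to leave that split unset.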

load_state_dict(state_dict)

Called when loading a checkpoint. Implement it to restore the datamodule's state from the given state_dict.

Parameters:

state_dict – the datamodule state returned by state_dict.

predict_dataloader()

Return the prediction DataLoader.

setup(stage)

Called at the beginning of fit (train + validate), validate, test, or predict. This is a good hook when you need to build models dynamically or adjust something about them. This hook is called on every process when using DDP.

Parameters:

stage – either 'fit', 'validate', 'test', or 'predict'

Example:

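import lightning as L
import torch.nn as nn


class LitModel(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.l1 = None

    def prepare_data(self):
        # download_data() and tokenize() are placeholders for your own code.
        download_data()
        tokenize()

        # Don't do this: prepare_data() runs on a single process, so state
        # assigned here is not visible to the other DDP processes.
        # self.something = ...

    def setup(self, stage):
        # setup() runs on every process, so assigning state here is safe.
        data = load_data(...)  # load_data() is also a placeholder
        self.l1 = nn.Linear(28, data.num_classes)
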

state_dict()

Called when saving a checkpoint. Implement it to build and return the datamodule state that should be saved.

Returns:

A dictionary containing datamodule state.

teardown(stage: str)

Called at the end of fit (train + validate), validate, test, or predict.

Parameters:

stage – either 'fit', 'validate', 'test', or 'predict'

test_dataloader()

Return the test DataLoader.

train_dataloader()

Return the training DataLoader.

val_dataloader()

Return the validation DataLoader (or an empty list if unset).

Examples using DataModule:

Multi-layer Probe for Vision Models

Supervised Learning Example