DataModule
- class stable_pretraining.data.DataModule(train: dict | DictConfig | DataLoader | None = None, test: dict | DictConfig | DataLoader | None = None, val: dict | DictConfig | DataLoader | None = None, predict: dict | DictConfig | DataLoader | None = None, **kwargs)
Bases: LightningDataModule
PyTorch Lightning DataModule for handling train/val/test/predict dataloaders.
- load_state_dict(state_dict)
Called when loading a checkpoint. Implement this to reload the datamodule state from the given datamodule state_dict.
- Parameters:
state_dict – the datamodule state returned by state_dict().
- setup(stage)
Called at the beginning of fit (train + validate), validate, test, or predict. This is a good hook when you need to build models dynamically or adjust something about them. This hook is called on every process when using DDP.
- Parameters:
stage – either 'fit', 'validate', 'test', or 'predict'
Example:
    class LitModel(...):
        def __init__(self):
            self.l1 = None

        def prepare_data(self):
            download_data()
            tokenize()

            # don't do this
            self.something = else

        def setup(self, stage):
            data = load_data(...)
            self.l1 = nn.Linear(28, data.num_classes)
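The stage argument can also be used to build only the splits a given run needs. The following is a minimal, dependency-free sketch of that idea, not the implementation used by stable_pretraining.data.DataModule. The class StageAwareData, its STAGES table, and the sources argument are hypothetical names introduced for illustration; a real datamodule would subclass LightningDataModule and return dataloaders rather than lists.

```python
class StageAwareData:
    """Hypothetical stand-in for a datamodule; it does not import Lightning."""

    # Which splits each stage needs. 'fit' covers train + validate.
    STAGES = {
        "fit": ("train", "val"),
        "validate": ("val",),
        "test": ("test",),
        "predict": ("predict",),
    }

    def __init__(self, sources):
        # sources maps a split name to a zero-argument callable that builds it.
        self.sources = sources
        self.splits = {}

    def setup(self, stage):
        if stage not in self.STAGES:
            raise ValueError(f"unknown stage: {stage!r}")
        for name in self.STAGES[stage]:
            # setup can run more than once, so skip splits that already exist.
            if name not in self.splits:
                self.splits[name] = self.sources[name]()


# Toy sources: each split is just a short list of labelled strings.
dm = StageAwareData(
    {
        name: (lambda n=name: [f"{n}-{i}" for i in range(3)])
        for name in ("train", "val", "test", "predict")
    }
)
dm.setup("fit")
built_after_fit = sorted(dm.splits)
dm.setup("test")
built_after_test = sorted(dm.splits)
```

Because setup runs on every process under DDP, any state assigned here exists on each process, which is why it is the right place for per-process attributes and prepare_data is not.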
- state_dict()
Called when saving a checkpoint. Implement this to generate and save the datamodule state.
- Returns:
A dictionary containing datamodule state.
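The state_dict and load_state_dict hooks form a round trip: whatever dictionary one returns, the other must be able to consume. A minimal sketch of that contract follows, written as a plain class so it runs without Lightning installed. The class ResumableData, the samples_seen counter, and the "datamodule" checkpoint key are all hypothetical; they are not taken from stable_pretraining or from Lightning's checkpoint layout.

```python
class ResumableData:
    """Hypothetical stand-in showing the state_dict / load_state_dict contract."""

    def __init__(self):
        # Example of state worth persisting: progress through the data.
        self.samples_seen = 0

    def state_dict(self):
        # Return a plain dictionary describing the datamodule state.
        return {"samples_seen": self.samples_seen}

    def load_state_dict(self, state_dict):
        # Restore exactly what state_dict() produced.
        self.samples_seen = state_dict["samples_seen"]


# Save side: capture the state of a datamodule that has made progress.
original = ResumableData()
original.samples_seen = 1280
checkpoint = {"datamodule": original.state_dict()}  # hypothetical layout

# Load side: a fresh instance picks up where the original left off.
restored = ResumableData()
restored.load_state_dict(checkpoint["datamodule"])
```

Keeping the returned dictionary limited to simple, serializable values makes the state easy to store alongside the rest of a checkpoint.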