Data2Vec

class stable_pretraining.methods.Data2Vec(encoder_name: str | Module = 'vit_small_patch16_224', top_k_blocks: int = 6, mask_ratio: float = 0.6, ema_decay_start: float = 0.999, ema_decay_end: float = 0.9999, image_size: int = 224, pretrained: bool = False)

Bases: Module

data2vec for vision: a student encoder predicts the features of an EMA teacher, averaged over the teacher's top transformer blocks.

Parameters:
  • encoder_name – timm ViT name (default "vit_small_patch16_224").

  • top_k_blocks – Number of top transformer blocks averaged on the teacher side to form the prediction target (default 6).

  • mask_ratio – Fraction of patch tokens masked on the student input (default 0.6). Masked tokens are replaced by a learnable token before the encoder.

  • ema_decay_start – Initial teacher EMA decay rate (default 0.999).

  • ema_decay_end – Final teacher EMA decay rate (default 0.9999).

  • image_size – Input image size in pixels (default 224).

  • pretrained – Whether to load pretrained timm weights (default False).
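
The sketch below illustrates how the EMA and top-k parameters could fit together. It is not the library's implementation. It assumes a linear ramp of the decay from ema_decay_start to ema_decay_end, which this page does not confirm, and it uses plain Python lists in place of tensors. The helper names (ema_decay, ema_update, average_top_k) and the ramp_steps argument are hypothetical.

```python
def ema_decay(step: int, ramp_steps: int, start: float = 0.999, end: float = 0.9999) -> float:
    """Decay rate at a given step.

    Assumption: linear ramp from `start` to `end` over `ramp_steps`,
    then constant at `end`. The actual schedule may differ.
    """
    if ramp_steps <= 0 or step >= ramp_steps:
        return end
    return start + (end - start) * step / ramp_steps


def ema_update(teacher: list[float], student: list[float], decay: float) -> list[float]:
    """One EMA step per weight: t <- decay * t + (1 - decay) * s."""
    return [decay * t + (1.0 - decay) * s for t, s in zip(teacher, student)]


def average_top_k(block_outputs: list[list[float]], k: int = 6) -> list[float]:
    """Element-wise mean of the last `k` block outputs.

    `block_outputs` holds one feature vector per transformer block,
    ordered from the first block to the last.
    """
    if k < 1:
        raise ValueError("k must be at least 1")
    top = block_outputs[-k:]
    return [sum(values) / len(top) for values in zip(*top)]
```

With the default values, the decay starts at 0.999, so the teacher initially tracks the student relatively quickly, and it ends at 0.9999, so the teacher changes more slowly late in training.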

forward(images: Tensor) → Data2VecOutput

Forward pass.

Parameters:
  • images – Image batch of shape [B, C, H, W].
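
As a rough guide to what the input shape implies, the sketch below computes how many patch tokens a square image yields and how many of them a given mask ratio would cover. The patch size of 16 is read off the default encoder name ("patch16"). Rounding the masked count down is an assumption, and patch_token_counts is a hypothetical helper, not part of the library.

```python
def patch_token_counts(
    image_size: int = 224, patch_size: int = 16, mask_ratio: float = 0.6
) -> tuple[int, int]:
    """Return (patch tokens, masked tokens) for a square image."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    per_side = image_size // patch_size
    num_patches = per_side * per_side
    # Assumption: the masked count is rounded down; the library may round differently.
    num_masked = int(num_patches * mask_ratio)
    return num_patches, num_masked


print(patch_token_counts())  # (196, 117)
```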