MaskFeat
- class stable_pretraining.methods.MaskFeat(encoder_name: str | Module = 'vit_small_patch16_224', patch_size: int = 16, mask_ratio: float = 0.4, n_hog_bins: int = 9, image_size: int = 224, in_channels: int = 3, pretrained: bool = False)
Bases: Module
MaskFeat: predict per-patch HOG (histogram of oriented gradients) features at masked positions.
- Parameters:
encoder_name – timm ViT name (default "vit_small_patch16_224").
patch_size – Patch size (default 16, must match encoder).
mask_ratio – Fraction of patches masked (default 0.4).
n_hog_bins – HOG orientation bins (default 9).
image_size – Input size (default 224).
in_channels – Image channels (default 3).
pretrained – Load pretrained timm weights (default False).
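To make the relationship between image_size, patch_size, and mask_ratio concrete, here is a minimal, dependency-free sketch. The helper names `patch_grid` and `random_patch_mask` are hypothetical and not part of stable_pretraining. Truncating the masked count with `int()` and sampling uniformly at random are assumptions; the library may round or sample differently.

```python
import random


def patch_grid(image_size=224, patch_size=16):
    """Return (patches_per_side, total_patches) for a square image."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    side = image_size // patch_size
    return side, side * side


def random_patch_mask(num_patches, mask_ratio=0.4, seed=None):
    """Return a list of booleans; True marks a masked patch.

    Assumption: the masked count is truncated toward zero and patches
    are drawn uniformly without replacement.
    """
    rng = random.Random(seed)
    num_masked = int(num_patches * mask_ratio)
    masked = set(rng.sample(range(num_patches), num_masked))
    return [i in masked for i in range(num_patches)]


side, num_patches = patch_grid()        # defaults: 224 px image, 16 px patches
mask = random_patch_mask(num_patches, mask_ratio=0.4, seed=0)
print(side, num_patches, sum(mask))     # 14 196 78
```

With the defaults, the image splits into a 14 × 14 grid of 196 patches, of which 78 are masked under this truncation assumption.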
- forward(images: Tensor) → MaskFeatOutput
Same as torch.nn.Module.forward().
- Parameters:
images – Input images (Tensor), as given in the signature above.
- Returns:
The model's output, a MaskFeatOutput.
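The class description says the model predicts per-patch HOG at masked positions. As an illustration of what a per-patch HOG descriptor can look like, here is a simplified pure-Python sketch, not the library's implementation. The function `patch_hog` is hypothetical. It assumes a grayscale image given as a list of rows, unsigned gradient orientations in [0, π), hard bin assignment, and one L2-normalized histogram per patch; the actual target computation in stable_pretraining may differ in all of these choices.

```python
import math


def patch_hog(image, patch_size=16, n_bins=9):
    """Return one L2-normalized orientation histogram per patch.

    image: list of rows of floats (grayscale).
    Patches are ordered row-major over the patch grid.
    """
    h, w = len(image), len(image[0])
    gh, gw = h // patch_size, w // patch_size
    hists = [[0.0] * n_bins for _ in range(gh * gw)]

    for y in range(gh * patch_size):
        for x in range(gw * patch_size):
            # Central differences, clamped at the image border.
            gx = image[y][min(x + 1, w - 1)] - image[y][max(x - 1, 0)]
            gy = image[min(y + 1, h - 1)][x] - image[max(y - 1, 0)][x]
            mag = math.hypot(gx, gy)
            # Unsigned orientation: fold the angle into [0, pi).
            angle = math.atan2(gy, gx) % math.pi
            b = min(int(angle / math.pi * n_bins), n_bins - 1)
            patch = (y // patch_size) * gw + (x // patch_size)
            hists[patch][b] += mag  # magnitude-weighted vote

    # L2-normalize each histogram; all-zero (flat) patches stay zero.
    for hist in hists:
        norm = math.sqrt(sum(v * v for v in hist))
        if norm > 0:
            for i in range(n_bins):
                hist[i] /= norm
    return hists


# A 16x16 image with a single vertical edge: the gradient is purely
# horizontal, so all weight lands in the first orientation bin.
edge = [[1.0 if x >= 8 else 0.0 for x in range(16)] for _ in range(16)]
print(patch_hog(edge)[0])
```

In a masked-prediction setup, descriptors like these would serve as regression targets only at the positions where the patch mask is set.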