Source code for stable_pretraining.methods.mim_refiner

"""MIM-Refiner: refining MIM-pretrained features with self-distillation.

A thin convenience wrapper around :class:`iBOT` that takes a *pretrained*
MIM encoder (MAE, SimMIM, data2vec, ...) and refines its features with
DINO+iBOT-style self-distillation. The hypothesis is that MIM gives
strong patch-level features but weak global ones; a short DINO+iBOT
phase aligns the global head while preserving the patch features.

References:
    Lehner et al. "MIM-Refiner: A Contrastive Learning Boost from
    Pre-Trained Vision Models." arXiv 2024.
    https://arxiv.org/abs/2402.10093
"""

import torch.nn as nn

from .ibot import iBOT, iBOTOutput


__all__ = ["MIMRefiner", "MIMRefinerOutput"]

MIMRefinerOutput = iBOTOutput


[docs] class MIMRefiner(iBOT): """Refine a pretrained MIM encoder with iBOT-style self-distillation. :param pretrained_encoder: A pre-trained ``nn.Module`` (e.g. the encoder of a trained ``MAE`` / ``SimMIM`` / ``Data2Vec`` instance, or a timm ViT loaded with ``pretrained=True``). Required. :param freeze_lower_blocks: Number of leading transformer blocks to freeze on the student (default 0). The teacher's EMA already holds the MIM features regardless. :param **ibot_kwargs: Forwarded to :class:`iBOT` (projector dims, prototypes, mask ratio, etc.). """ def __init__( self, pretrained_encoder: nn.Module, freeze_lower_blocks: int = 0, **ibot_kwargs, ): # iBOT accepts a pre-built encoder via the ``encoder_name`` arg. super().__init__(encoder_name=pretrained_encoder, **ibot_kwargs) if freeze_lower_blocks > 0: student_vit = self.backbone.student for i, block in enumerate(student_vit.blocks): if i < freeze_lower_blocks: for p in block.parameters(): p.requires_grad_(False)